diff --git a/shell/AIShell.Abstraction/AIShell.Abstraction.csproj b/shell/AIShell.Abstraction/AIShell.Abstraction.csproj index d6ed2e77..1eed2edf 100644 --- a/shell/AIShell.Abstraction/AIShell.Abstraction.csproj +++ b/shell/AIShell.Abstraction/AIShell.Abstraction.csproj @@ -7,5 +7,6 @@ + diff --git a/shell/AIShell.Abstraction/IShell.cs b/shell/AIShell.Abstraction/IShell.cs index 8b1ab7dd..c0c04867 100644 --- a/shell/AIShell.Abstraction/IShell.cs +++ b/shell/AIShell.Abstraction/IShell.cs @@ -1,3 +1,5 @@ +using Microsoft.Extensions.AI; + namespace AIShell.Abstraction; /// @@ -27,6 +29,26 @@ public interface IShell /// A list of code blocks or null if there is no code block. List ExtractCodeBlocks(string text, out List sourceInfos); + /// + /// Get available instances for LLM to use. + /// + /// + Task> GetAIFunctions(); + + /// + /// Call an AI function. + /// + /// A instance representing the function call request. + /// Whether or not to capture the exception thrown from calling the tool. + /// Whether or not to include the exception message to the message of the call result. + /// The cancellation token to cancel the call. + /// + Task CallAIFunction( + FunctionCallContent functionCall, + bool captureException, + bool includeDetailedErrors, + CancellationToken cancellationToken); + // TODO: // - methods to run code: python, command-line, powershell, node-js. // - methods to communicate with shell client. 
diff --git a/shell/AIShell.App/AIShell.App.csproj b/shell/AIShell.App/AIShell.App.csproj index dd8f9592..86969f52 100644 --- a/shell/AIShell.App/AIShell.App.csproj +++ b/shell/AIShell.App/AIShell.App.csproj @@ -12,7 +12,6 @@ - diff --git a/shell/AIShell.Kernel/AIShell.Kernel.csproj b/shell/AIShell.Kernel/AIShell.Kernel.csproj index 442b0dbf..5a58665f 100644 --- a/shell/AIShell.Kernel/AIShell.Kernel.csproj +++ b/shell/AIShell.Kernel/AIShell.Kernel.csproj @@ -6,8 +6,8 @@ - - + + diff --git a/shell/AIShell.Kernel/Command/CommandRunner.cs b/shell/AIShell.Kernel/Command/CommandRunner.cs index 515d430f..1043b532 100644 --- a/shell/AIShell.Kernel/Command/CommandRunner.cs +++ b/shell/AIShell.Kernel/Command/CommandRunner.cs @@ -35,6 +35,7 @@ internal CommandRunner(Shell shell) new RefreshCommand(), new RetryCommand(), new HelpCommand(), + new McpCommand(), //new RenderCommand(), }; diff --git a/shell/AIShell.Kernel/Command/McpCommand.cs b/shell/AIShell.Kernel/Command/McpCommand.cs new file mode 100644 index 00000000..58f452cc --- /dev/null +++ b/shell/AIShell.Kernel/Command/McpCommand.cs @@ -0,0 +1,40 @@ +using System.CommandLine; +using AIShell.Abstraction; + +namespace AIShell.Kernel.Commands; + +internal sealed class McpCommand : CommandBase +{ + public McpCommand() + : base("mcp", "Command for managing MCP servers and tools.") + { + this.SetHandler(ShowMCPData); + + //var start = new Command("start", "Start an MCP server."); + //var stop = new Command("stop", "Stop an MCP server."); + //var server = new Argument( + // name: "server", + // getDefaultValue: () => null, + // description: "Name of an MCP server.").AddCompletions(AgentCompleter); + + //start.AddArgument(server); + //start.SetHandler(StartMcpServer, server); + + //stop.AddArgument(server); + //stop.SetHandler(StopMcpServer, server); + } + + private void ShowMCPData() + { + var shell = (Shell)Shell; + var host = shell.Host; + + if (shell.McpManager.McpServers.Count is 0) + { + host.WriteErrorLine("No MCP 
server is available."); + return; + } + + host.RenderMcpServersAndTools(shell.McpManager); + } +} diff --git a/shell/AIShell.Kernel/Host.cs b/shell/AIShell.Kernel/Host.cs index 27e94544..472fe223 100644 --- a/shell/AIShell.Kernel/Host.cs +++ b/shell/AIShell.Kernel/Host.cs @@ -2,9 +2,12 @@ using System.Text; using AIShell.Abstraction; +using AIShell.Kernel.Mcp; using Markdig.Helpers; using Microsoft.PowerShell; using Spectre.Console; +using Spectre.Console.Json; +using Spectre.Console.Rendering; namespace AIShell.Kernel; @@ -175,7 +178,7 @@ public void RenderFullResponse(string response) /// public void RenderTable(IList sources) { - RequireStdoutOrStderr(operation: "render table"); + RequireStdout(operation: "render table"); ArgumentNullException.ThrowIfNull(sources); if (sources.Count is 0) @@ -198,7 +201,7 @@ public void RenderTable(IList sources) /// public void RenderTable(IList sources, IList> elements) { - RequireStdoutOrStderr(operation: "render table"); + RequireStdout(operation: "render table"); ArgumentNullException.ThrowIfNull(sources); ArgumentNullException.ThrowIfNull(elements); @@ -240,7 +243,7 @@ public void RenderTable(IList sources, IList> elements) /// public void RenderList(T source) { - RequireStdoutOrStderr(operation: "render list"); + RequireStdout(operation: "render list"); ArgumentNullException.ThrowIfNull(source); if (source is IDictionary dict) @@ -271,7 +274,7 @@ public void RenderList(T source) /// public void RenderList(T source, IList> elements) { - RequireStdoutOrStderr(operation: "render list"); + RequireStdout(operation: "render list"); ArgumentNullException.ThrowIfNull(source); ArgumentNullException.ThrowIfNull(elements); @@ -313,7 +316,7 @@ public void RenderList(T source, IList> elements) public void RenderDivider(string text, DividerAlignment alignment) { ArgumentException.ThrowIfNullOrEmpty(text); - RequireStdoutOrStderr(operation: "render divider"); + RequireStdout(operation: "render divider"); if (!text.Contains("[/]")) { @@ 
-550,15 +553,134 @@ public string PromptForArgument(ArgumentInfo argInfo, bool printCaption) internal void RenderReferenceText(string header, string content) { RequireStdoutOrStderr(operation: "Render reference"); + IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console; var panel = new Panel($"\n[italic]{content.EscapeMarkup()}[/]\n") .RoundedBorder() .BorderColor(Color.DarkCyan) .Header($"[orange3 on italic] {header.Trim()} [/]"); - AnsiConsole.WriteLine(); - AnsiConsole.Write(panel); - AnsiConsole.WriteLine(); + ansiConsole.WriteLine(); + ansiConsole.Write(panel); + ansiConsole.WriteLine(); + } + + /// + /// Render the MCP tool call request. + /// + /// The MCP tool. + /// The arguments in JSON form to be sent for the tool call. + internal void RenderToolCallRequest(McpTool tool, string jsonArgs) + { + RequireStdoutOrStderr(operation: "render tool call request"); + IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console; + + bool hasArgs = !string.IsNullOrEmpty(jsonArgs); + IRenderable content = new Markup($""" + + [bold]Run [olive]{tool.OriginalName}[/] from [olive]{tool.ServerName}[/] (MCP server)[/] + + {tool.Description} + + Input:{(hasArgs ? string.Empty : " ")} + """); + + if (hasArgs) + { + var json = new JsonText(jsonArgs) + .MemberColor(Color.Aqua) + .ColonColor(Color.White) + .CommaColor(Color.White) + .StringStyle(Color.Tan); + + content = new Grid() + .AddColumn(new GridColumn()) + .AddRow(content) + .AddRow(json); + } + + var panel = new Panel(content) + .Expand() + .RoundedBorder() + .Header("[green] Tool Call Request [/]") + .BorderColor(Color.Grey); + + ansiConsole.WriteLine(); + ansiConsole.Write(panel); + FancyStreamRender.ConsoleUpdated(); + } + + /// + /// Render a table with information about available MCP servers and tools. + /// + /// The MCP manager instance. 
+ internal void RenderMcpServersAndTools(McpManager mcpManager) + { + RequireStdout(operation: "render MCP servers and tools"); + + var toolTable = new Table() + .LeftAligned() + .SimpleBorder() + .BorderColor(Color.Green); + + toolTable.AddColumn("[green bold]Server[/]"); + toolTable.AddColumn("[green bold]Tool[/]"); + toolTable.AddColumn("[green bold]Description[/]"); + + List<(string name, string status, string info)> readyServers = null, startingServers = null, failedServers = null; + foreach (var (name, server) in mcpManager.McpServers) + { + (int code, string status, string info) = server.IsInitFinished + ? server.Error is null + ? (1, "[green]\u2713 Ready[/]", string.Empty) + : (-1, "[red]\u2717 Failed[/]", $"[red]{server.Error.Message.EscapeMarkup()}[/]") + : (0, "[yellow]\u25CB Starting[/]", string.Empty); + + var list = code switch + { + 1 => readyServers ??= [], + 0 => startingServers ??= [], + _ => failedServers ??= [], + }; + + list.Add((name, status, info)); + } + + if (startingServers is not null) + { + foreach (var (name, status, info) in startingServers) + { + toolTable.AddRow($"[olive underline]{name}[/]", status, info); + } + } + + if (failedServers is not null) + { + foreach (var (name, status, info) in failedServers) + { + toolTable.AddRow($"[olive underline]{name}[/]", status, info); + } + } + + if (readyServers is not null) + { + foreach (var (name, status, info) in readyServers) + { + if (toolTable.Rows is { Count: > 0 }) + { + toolTable.AddEmptyRow(); + } + + var server = mcpManager.McpServers[name]; + toolTable.AddRow($"[olive underline]{name}[/]", status, info); + foreach (var item in server.Tools) + { + toolTable.AddRow(string.Empty, item.Key.EscapeMarkup(), item.Value.Description.EscapeMarkup()); + } + } + } + + AnsiConsole.Write(toolTable); } private static Spinner GetSpinner(SpinnerKind? kind) @@ -583,6 +705,19 @@ private void RequireStdin(string operation) } } + /// + /// Throw exception if standard output is redirected. 
+ /// + /// The intended operation. + /// Throw the exception if stdout is redirected. + private void RequireStdout(string operation) + { + if (_outputRedirected) + { + throw new InvalidOperationException($"Cannot {operation} when the stdout is redirected."); + } + } + /// /// Throw exception if both standard output and error are redirected. /// diff --git a/shell/AIShell.Kernel/MCP/McpConfig.cs b/shell/AIShell.Kernel/MCP/McpConfig.cs new file mode 100644 index 00000000..a8453eea --- /dev/null +++ b/shell/AIShell.Kernel/MCP/McpConfig.cs @@ -0,0 +1,195 @@ +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +namespace AIShell.Kernel.Mcp; + +/// +/// MCP configuration defined in mcp.json. +/// +internal class McpConfig +{ + [JsonPropertyName("servers")] + public Dictionary Servers { get; set; } = []; + + internal static McpConfig Load() + { + McpConfig mcpConfig = null; + FileInfo file = new(Utils.AppMcpFile); + if (file.Exists) + { + using var stream = file.OpenRead(); + mcpConfig = JsonSerializer.Deserialize(stream, McpJsonContext.Default.McpConfig); + mcpConfig.Validate(); + } + + return mcpConfig is { Servers.Count: 0 } ? null : mcpConfig; + } + + /// + /// Post-deserialization validation. 
+ /// + /// + private void Validate() + { + List allErrors = null; + + foreach (var (name, server) in Servers) + { + server.Name = name; + if (Enum.TryParse(server.Type, ignoreCase: true, out McpType mcpType)) + { + server.Transport = mcpType; + } + else + { + (allErrors ??= []).Add($"Server '{name}': 'type' is required and the value should be one of the following: {string.Join(',', Enum.GetNames())}."); + continue; + } + + List curErrs = null; + + if (mcpType is McpType.stdio) + { + bool hasUrlGroup = !string.IsNullOrEmpty(server.Url) || server.Headers is { }; + if (hasUrlGroup) + { + (curErrs ??= []).Add($"'url' and 'headers' fields are invalid for 'stdio' type servers."); + } + + if (string.IsNullOrEmpty(server.Command)) + { + (curErrs ??= []).Add($"'command' is required for 'stdio' type servers."); + } + } + else + { + bool hasCommandGroup = !string.IsNullOrEmpty(server.Command) || server.Args is { } || server.Env is { }; + if (hasCommandGroup) + { + (curErrs ??= []).Add($"'command', 'args', and 'env' fields are invalid for '{mcpType}' type servers."); + } + + if (string.IsNullOrEmpty(server.Url)) + { + (curErrs ??= []).Add($"'url' is required for '{mcpType}' type servers."); + } + else if (!Uri.TryCreate(server.Url, UriKind.Absolute, out Uri uri)) + { + (curErrs ??= []).Add($"the specified value for 'url' is not a valid URI."); + } + else if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) + { + (curErrs ??= []).Add($"'url' is expected to be a 'http' or 'https' resource."); + } + else + { + server.Endpoint = uri; + } + } + + if (curErrs is [string onlyErr]) + { + (allErrors ??= []).Add($"Server '{name}': {onlyErr}"); + } + else if (curErrs is { Count: > 1 }) + { + string prefix = $"Server '{name}':"; + int size = curErrs.Sum(a => a.Length) + curErrs.Count * 5 + prefix.Length; + StringBuilder sb = new(prefix, capacity: size); + + foreach (string element in curErrs) + { + sb.Append($"\n - {element}"); + } + + (allErrors ??= 
[]).Add(sb.ToString()); + } + } + + if (allErrors is { }) + { + string errorMsg = string.Join('\n', allErrors); + throw new InvalidOperationException(errorMsg); + } + } +} + +/// +/// Configuration of a server. +/// +internal class McpServerConfig +{ + [JsonPropertyName("type")] + public string Type { get; set; } + + [JsonPropertyName("command")] + public string Command { get; set; } + + [JsonPropertyName("args")] + public List Args { get; set; } + + [JsonPropertyName("env")] + public Dictionary Env { get; set; } + + [JsonPropertyName("url")] + public string Url { get; set; } + + [JsonPropertyName("headers")] + public Dictionary Headers { get; set; } + + internal string Name { get; set; } + internal Uri Endpoint { get; set; } + internal McpType Transport { get; set; } + + internal IClientTransport ToClientTransport() + { + return Transport switch + { + McpType.stdio => new StdioClientTransport(new() + { + Name = Name, + Command = Command, + Arguments = Args, + EnvironmentVariables = Env, + }), + + _ => new SseClientTransport(new() + { + Name = Name, + Endpoint = Endpoint, + AdditionalHeaders = Headers, + TransportMode = Transport is McpType.sse ? HttpTransportMode.Sse : HttpTransportMode.StreamableHttp, + ConnectionTimeout = TimeSpan.FromSeconds(15), + }) + }; + } +} + +/// +/// MCP transport types. +/// +internal enum McpType +{ + stdio, + sse, + http, +} + +/// +/// Source generation helper for deserializing mcp.json. 
+/// +[JsonSourceGenerationOptions( + WriteIndented = true, + AllowTrailingCommas = true, + PropertyNameCaseInsensitive = true, + ReadCommentHandling = JsonCommentHandling.Skip, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + UseStringEnumConverter = true)] +[JsonSerializable(typeof(McpConfig))] +[JsonSerializable(typeof(McpServerConfig))] +[JsonSerializable(typeof(CallToolResponse))] +internal partial class McpJsonContext : JsonSerializerContext { } diff --git a/shell/AIShell.Kernel/MCP/McpManager.cs b/shell/AIShell.Kernel/MCP/McpManager.cs new file mode 100644 index 00000000..b70575e8 --- /dev/null +++ b/shell/AIShell.Kernel/MCP/McpManager.cs @@ -0,0 +1,168 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +namespace AIShell.Kernel.Mcp; + +internal class McpManager +{ + private readonly Task _initTask; + private readonly McpServerInitContext _context; + private readonly Dictionary _mcpServers; + private readonly TaskCompletionSource _parseMcpJsonTaskSource; + + private McpConfig _mcpConfig; + + internal Task ParseMcpJsonTask => _parseMcpJsonTaskSource.Task; + + internal Dictionary McpServers + { + get + { + _initTask.Wait(); + return _mcpServers; + } + } + + internal McpManager(Shell shell) + { + _context = new(shell); + _parseMcpJsonTaskSource = new(); + _mcpServers = new(StringComparer.OrdinalIgnoreCase); + + _initTask = Task.Run(Initialize); + } + + private void Initialize() + { + try + { + _mcpConfig = McpConfig.Load(); + _parseMcpJsonTaskSource.SetResult(_mcpConfig); + } + catch (Exception e) + { + _parseMcpJsonTaskSource.SetException(e); + } + + if (_mcpConfig is null) + { + return; + } + + foreach (var (name, config) in _mcpConfig.Servers) + { + _mcpServers.Add(name, new McpServer(config, _context)); + } + } + + /// + /// Lists tools that are available at the time of the call. + /// Servers that are still initializing or failed will be skipped. 
+ /// + internal async Task> ListAvailableTools() + { + await _initTask; + + List tools = null; + foreach (var (name, server) in _mcpServers) + { + if (server.IsOperational) + { + (tools ??= []).AddRange(server.Tools.Values); + } + } + + return tools; + } + + /// + /// Make a tool call using the given function call data. + /// + /// The function call request. + /// Whether or not to capture the exception thrown from calling the tool. + /// Whether or not to include the exception message to the message of the call result. + /// The cancellation token to cancel the call. + /// + internal async Task CallToolAsync( + FunctionCallContent functionCall, + bool captureException = false, + bool includeDetailedErrors = false, + CancellationToken cancellationToken = default) + { + string serverName = null, toolName = null; + + string functionName = functionCall.Name; + int dotIndex = functionName.IndexOf(McpTool.ServerToolSeparator); + if (dotIndex > 0) + { + serverName = functionName[..dotIndex]; + toolName = functionName[(dotIndex + 1)..]; + } + + await _initTask; + + McpTool tool = null; + if (!string.IsNullOrEmpty(serverName) + && !string.IsNullOrEmpty(toolName) + && _mcpServers.TryGetValue(serverName, out McpServer server)) + { + await server.WaitForInitAsync(cancellationToken); + server.Tools.TryGetValue(toolName, out tool); + } + + if (tool is null) + { + return new FunctionResultContent( + functionCall.CallId, + $"Error: Requested function \"{functionName}\" not found."); + } + + FunctionResultContent resultContent = new(functionCall.CallId, result: null); + + try + { + CallToolResponse response = await tool.CallAsync( + new AIFunctionArguments(arguments: functionCall.Arguments), + cancellationToken: cancellationToken); + + resultContent.Result = (object)response ?? 
"Success: Function completed."; + } + catch (Exception e) when (!cancellationToken.IsCancellationRequested) + { + if (!captureException) + { + throw; + } + + string message = "Error: Function failed."; + resultContent.Exception = e; + resultContent.Result = includeDetailedErrors ? $"{message} Exception: {e.Message}" : message; + } + + return resultContent; + } +} + +internal class McpServerInitContext +{ + /// + /// The throttle limit defines the maximum number of servers that can be initiated concurrently. + /// + private const int ThrottleLimit = 5; + + internal McpServerInitContext(Shell shell) + { + Shell = shell; + ThrottleSemaphore = new SemaphoreSlim(ThrottleLimit, ThrottleLimit); + ClientOptions = new() + { + ClientInfo = new() { Name = "AIShell", Version = shell.Version }, + InitializationTimeout = TimeSpan.FromSeconds(30), + }; + } + + internal Shell Shell { get; } + internal SemaphoreSlim ThrottleSemaphore { get; } + internal McpClientOptions ClientOptions { get; } +} diff --git a/shell/AIShell.Kernel/MCP/McpServer.cs b/shell/AIShell.Kernel/MCP/McpServer.cs new file mode 100644 index 00000000..df43d3f8 --- /dev/null +++ b/shell/AIShell.Kernel/MCP/McpServer.cs @@ -0,0 +1,135 @@ +using ModelContextProtocol.Client; + +namespace AIShell.Kernel.Mcp; + +internal class McpServer : IDisposable +{ + private readonly McpServerConfig _config; + private readonly McpServerInitContext _context; + private readonly Dictionary _tools; + private readonly Task _initTask; + + private string _serverInfo; + private IMcpClient _client; + private Exception _error; + + /// + /// Name of the server declared in mcp.json. + /// + internal string Name => _config.Name; + + /// + /// Gets whether the initialization is done. + /// + internal bool IsInitFinished => _initTask.IsCompleted; + + /// + /// Gets whether the server is operational. + /// + internal bool IsOperational => _initTask.IsCompleted && _error is null; + + /// + /// Full name and version of the server. 
+ /// + internal string ServerInfo + { + get + { + WaitForInit(); + return _serverInfo; + } + } + + /// + /// The client connected to the server. + /// + internal IMcpClient Client + { + get + { + WaitForInit(); + return _client; + } + } + + /// + /// Exposed tools from the server. + /// + internal Dictionary Tools + { + get + { + WaitForInit(); + return _tools; + } + } + + internal Exception Error + { + get + { + WaitForInit(); + return _error; + } + } + + internal McpServer(McpServerConfig config, McpServerInitContext context) + { + _config = config; + _context = context; + _tools = new(StringComparer.OrdinalIgnoreCase); + _initTask = Initialize(); + } + + private async Task Initialize() + { + try + { + await _context.ThrottleSemaphore.WaitAsync(); + + IClientTransport transport = _config.ToClientTransport(); + _client = await McpClientFactory.CreateAsync(transport, _context.ClientOptions); + + var serverInfo = _client.ServerInfo; + // An MCP server may have the name included in the version info. + _serverInfo = serverInfo.Version.Contains(serverInfo.Name, StringComparison.OrdinalIgnoreCase) + ? 
serverInfo.Version + : $"{serverInfo.Name} {serverInfo.Version}"; + + await foreach (McpClientTool tool in _client.EnumerateToolsAsync()) + { + _tools.TryAdd(tool.Name, new McpTool(Name, tool, _context.Shell.Host)); + } + } + catch (Exception e) + { + _error = e; + _tools.Clear(); + if (_client is { }) + { + await _client.DisposeAsync(); + _client = null; + } + } + finally + { + _context.ThrottleSemaphore.Release(); + } + } + + internal void WaitForInit(CancellationToken cancellationToken = default) + { + _initTask.Wait(cancellationToken); + } + + internal async Task WaitForInitAsync(CancellationToken cancellationToken = default) + { + await _initTask.WaitAsync(cancellationToken); + } + + public void Dispose() + { + _tools.Clear(); + _client?.DisposeAsync().AsTask().Wait(); + } +} diff --git a/shell/AIShell.Kernel/MCP/McpTool.cs b/shell/AIShell.Kernel/MCP/McpTool.cs new file mode 100644 index 00000000..10771b01 --- /dev/null +++ b/shell/AIShell.Kernel/MCP/McpTool.cs @@ -0,0 +1,137 @@ +using System.Text.Json; +using AIShell.Abstraction; +using Microsoft.Extensions.AI; +using ModelContextProtocol; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +namespace AIShell.Kernel.Mcp; + +/// +/// A wrapper class of to make sure the call to the tool always go through the AIShell. +/// +internal class McpTool : AIFunction +{ + private readonly string _fullName; + private readonly string _serverName; + private readonly Host _host; + private readonly McpClientTool _clientTool; + private readonly string[] _userChoices; + + internal const string ServerToolSeparator = "___"; + + internal McpTool(string serverName, McpClientTool clientTool, Host host) + { + _host = host; + _clientTool = clientTool; + _fullName = $"{serverName}{ServerToolSeparator}{clientTool.Name}"; + _serverName = serverName; + _userChoices = ["Continue", "Cancel"]; + } + + /// + /// The server name for this tool. 
+ /// + internal string ServerName => _serverName; + + /// + /// The original tool name without the server name prefix. + /// + internal string OriginalName => _clientTool.Name; + + /// + /// The fully qualified name of the tool in the form of '.' + /// + public override string Name => _fullName; + + /// + public override string Description => _clientTool.Description; + + /// + public override JsonElement JsonSchema => _clientTool.JsonSchema; + + /// + public override JsonSerializerOptions JsonSerializerOptions => _clientTool.JsonSerializerOptions; + + /// + public override IReadOnlyDictionary AdditionalProperties => _clientTool.AdditionalProperties; + + /// + /// Overrides the base method with the call to . The only difference in behavior is we will serialize + /// the resulting such that the returned is a + /// containing the serialized . This method is intended to be used polymorphically via the base + /// class, typically as part of an operation. + /// + protected override async ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + return JsonSerializer.SerializeToElement( + await CallAsync( + arguments, + progress: null, + JsonSerializerOptions, + cancellationToken).ConfigureAwait(false), + McpJsonContext.Default.CallToolResponse); + } + + /// + /// Invokes the tool on the server if the user approves. + /// + /// + /// An optional dictionary of arguments to pass to the tool. + /// Each key represents a parameter name, and its associated value represents the argument value. + /// + /// + /// An optional to have progress notifications reported to it. + /// Setting this to a non- value will result in a progress token being included in the call, + /// and any resulting progress notifications during the operation routed to this instance. + /// + /// + /// The JSON serialization options governing argument serialization. + /// If , the default serialization options will be used. 
+ /// + /// + /// The cancellation token to monitor for cancellation requests. + /// The default is . + /// + /// + /// A task containing the response from the tool execution. The response includes the tool's output content, which may be structured data, text, or an error message. + /// + /// + /// This method wraps the method to add the user interactions for displaying the tool call request and prompting the user for approval. + /// + /// The user rejected the tool call. + /// The server could not find the requested tool, or the server encountered an error while processing the request. + internal async ValueTask CallAsync( + IReadOnlyDictionary arguments = null, + IProgress progress = null, + JsonSerializerOptions serializerOptions = null, + CancellationToken cancellationToken = default) + { + // Display the tool call request. + string jsonArgs = arguments is { Count: > 0 } + ? JsonSerializer.Serialize(arguments, serializerOptions ?? JsonSerializerOptions) + : null; + _host.RenderToolCallRequest(this, jsonArgs); + + // Prompt for user's approval to call the tool. + const string title = "\n\u26A0 MCP servers or malicious conversation content may attempt to misuse 'AIShell' through the installed tools. 
Please carefully review any requested actions to decide if you want to proceed."; + string choice = await _host.PromptForSelectionAsync( + title: title, + choices: _userChoices, + cancellationToken: cancellationToken); + + if (choice is "Cancel") + { + _host.MarkupLine($"\n [red]\u2717[/] Cancelled '{OriginalName}'"); + throw new OperationCanceledException("The call was rejected by user."); + } + + CallToolResponse response = await _host.RunWithSpinnerAsync( + async () => await _clientTool.CallAsync(arguments, progress, serializerOptions, cancellationToken), + status: $"Running '{OriginalName}'", + spinnerKind: SpinnerKind.Processing); + + _host.MarkupLine($"\n [green]\u2713[/] Ran '{OriginalName}'"); + return response; + } +} diff --git a/shell/AIShell.Kernel/Render/StreamRender.cs b/shell/AIShell.Kernel/Render/StreamRender.cs index 118badab..5ac05063 100644 --- a/shell/AIShell.Kernel/Render/StreamRender.cs +++ b/shell/AIShell.Kernel/Render/StreamRender.cs @@ -38,32 +38,49 @@ internal sealed partial class FancyStreamRender : IStreamRender internal const char ESC = '\x1b'; internal static readonly Regex AnsiRegex = CreateAnsiRegex(); + private static int s_consoleUpdateFlag = 0; + + private readonly int _bufferWidth, _bufferHeight; private readonly MarkdownRender _markdownRender; private readonly StringBuilder _buffer; private readonly CancellationToken _cancellationToken; - private string _currentText; - private int _bufferWidth, _bufferHeight; + private int _localFlag; private Point _initialCursor; + private string _currentText; private string _accumulatedContent; + private List _previousContents; internal FancyStreamRender(MarkdownRender markdownRender, CancellationToken token) { - _currentText = string.Empty; _bufferWidth = Console.BufferWidth; _bufferHeight = Console.BufferHeight; - _initialCursor = new(Console.CursorLeft, Console.CursorTop); - _markdownRender = markdownRender; _buffer = new StringBuilder(); _cancellationToken = token; - _accumulatedContent = 
string.Empty; + + _localFlag = s_consoleUpdateFlag; + _initialCursor = new(Console.CursorLeft, Console.CursorTop); + _accumulatedContent = _currentText = string.Empty; + _previousContents = null; // Hide the cursor when rendering the streaming response. Console.CursorVisible = false; } - public string AccumulatedContent => _accumulatedContent; + public string AccumulatedContent + { + get + { + if (_previousContents is null) + { + return _accumulatedContent; + } + + _previousContents.Add(_accumulatedContent); + return string.Concat(_previousContents); + } + } public List CodeBlocks { @@ -72,7 +89,7 @@ public List CodeBlocks // Create a new list to return, so as to prevent agents from changing // the list that is used internally by 'CodeBlockVisitor'. var blocks = _markdownRender.GetAllCodeBlocks(); - return blocks is null ? null : new List(blocks); + return blocks is null ? null : [.. blocks]; } } @@ -93,6 +110,22 @@ public void Refresh(string newChunk) // Avoid rendering the new chunk up on cancellation. _cancellationToken.ThrowIfCancellationRequested(); + // The host wrote out something while this stream render is active. + // We need to reset the state of this stream render in this case. + if (_localFlag < s_consoleUpdateFlag) + { + _localFlag = s_consoleUpdateFlag; + _initialCursor = new(Console.CursorLeft, Console.CursorTop); + Console.CursorVisible = false; + + if (_buffer.Length > 0) + { + (_previousContents ??= []).Add(_accumulatedContent); + _accumulatedContent = _currentText = string.Empty; + _buffer.Clear(); + } + } + _buffer.Append(newChunk); _accumulatedContent = _buffer.ToString(); RefreshImpl(_markdownRender.RenderText(_accumulatedContent)); @@ -307,6 +340,23 @@ private Point ConvertOffsetToPoint(Point point, string text, int offset) return new Point(x, y); } + + /// + /// Call this method to report writing to console from outside the stream render. 
+ /// + /// + /// With the MCP tool calls, we may need to render the tool call request while a stream render is active. + /// This method is used to notify an active stream render about updates in console from elsewhere, so the + /// stream render can reset its state and start freshly. + /// Note that, a stream render and the host won't really write to console in parallel. But it is possible + /// that the stream render wrote some output and stopped, and then the host wrote some other output. Since + /// the stream render depends on the initial cursor position to refresh all content when new chunks coming + /// in, it needs to reset its state in such a case, so that it can continue to work correctly afterwards. + /// + internal static void ConsoleUpdated() + { + s_consoleUpdateFlag++; + } } internal struct Point diff --git a/shell/AIShell.Kernel/Setting.cs b/shell/AIShell.Kernel/Setting.cs index 02492e9f..087dd200 100644 --- a/shell/AIShell.Kernel/Setting.cs +++ b/shell/AIShell.Kernel/Setting.cs @@ -19,7 +19,7 @@ internal static Setting Load() try { using var stream = file.OpenRead(); - return JsonSerializer.Deserialize(stream, SourceGenerationContext.Default.Setting); + return JsonSerializer.Deserialize(stream, AppSettingJsonContext.Default.Setting); } catch (Exception e) { @@ -48,4 +48,4 @@ internal static Setting Load() ReadCommentHandling = JsonCommentHandling.Skip, UseStringEnumConverter = true)] [JsonSerializable(typeof(Setting))] -internal partial class SourceGenerationContext : JsonSerializerContext { } +internal partial class AppSettingJsonContext : JsonSerializerContext { } diff --git a/shell/AIShell.Kernel/Shell.cs b/shell/AIShell.Kernel/Shell.cs index c353ee4c..271c19f3 100644 --- a/shell/AIShell.Kernel/Shell.cs +++ b/shell/AIShell.Kernel/Shell.cs @@ -1,7 +1,9 @@ using System.Reflection; -using Microsoft.PowerShell; using AIShell.Abstraction; using AIShell.Kernel.Commands; +using AIShell.Kernel.Mcp; +using Microsoft.Extensions.AI; +using 
Microsoft.PowerShell; using Spectre.Console; namespace AIShell.Kernel; @@ -14,6 +16,7 @@ internal sealed class Shell : IShell private readonly ShellWrapper _wrapper; private readonly HashSet _textToIgnore; private readonly Setting _setting; + private readonly McpManager _mcpManager; private bool _shouldRefresh; private LLMAgent _activeAgent; @@ -74,12 +77,25 @@ internal sealed class Shell : IShell /// internal LLMAgent ActiveAgent => _activeAgent; + /// + /// Gets the version from the assembly attribute. + /// + internal string Version => _version; + + /// + /// Gets the MCP manager. + /// + internal McpManager McpManager => _mcpManager; + #region IShell implementation IHost IShell.Host => Host; bool IShell.ChannelEstablished => Channel is not null; CancellationToken IShell.CancellationToken => _cancellationSource.Token; List IShell.ExtractCodeBlocks(string text, out List sourceInfos) => Utils.ExtractCodeBlocks(text, out sourceInfos); + async Task> IShell.GetAIFunctions() => await _mcpManager.ListAvailableTools(); + async Task IShell.CallAIFunction(FunctionCallContent functionCall, bool captureException, bool includeDetailedErrors, CancellationToken cancellationToken) + => await _mcpManager.CallToolAsync(functionCall, captureException, includeDetailedErrors, cancellationToken); #endregion IShell implementation @@ -104,6 +120,7 @@ internal Shell(bool interactive, ShellArgs args) _textToIgnore = new HashSet(StringComparer.OrdinalIgnoreCase); _cancellationSource = new CancellationTokenSource(); _version = typeof(Shell).Assembly.GetCustomAttribute().InformationalVersion; + _mcpManager = new(this); Exit = false; Regenerate = false; @@ -156,9 +173,21 @@ internal void ShowLandingPage() _activeAgent.Display(Host, isWrapped ? _wrapper.Description : null); } + try + { + McpConfig config = _mcpManager.ParseMcpJsonTask.GetAwaiter().GetResult(); + if (config is { Servers.Count: > 0 }) + { + Host.MarkupNoteLine($"{config.Servers.Count} MCP server(s) configured. 
Run {Formatter.Command("/mcp")} for details."); + } + } + catch (Exception e) + { + Host.WriteErrorLine($"Failed to load the 'mcp.json' file:\n{e.Message}\nRun '/mcp config' to open the 'mcp.json' file.\n"); + } + // Write out help. - Host.MarkupLine($"Run {Formatter.Command("/help")} for more instructions.") - .WriteLine(); + Host.MarkupLine($"Run {Formatter.Command("/help")} for more instructions.\n"); } /// diff --git a/shell/AIShell.Kernel/Utility/LoadContext.cs b/shell/AIShell.Kernel/Utility/LoadContext.cs index 3cb70c08..ae00182d 100644 --- a/shell/AIShell.Kernel/Utility/LoadContext.cs +++ b/shell/AIShell.Kernel/Utility/LoadContext.cs @@ -25,7 +25,14 @@ internal AgentAssemblyLoadContext(string name, string dependencyDir) _dependencyDir = dependencyDir; _runtimeLibDir = []; _runtimeNativeDir = []; - _cache = []; + _cache = new() + { + // Contracts exposed from 'AIShell.Abstraction' depend on these assemblies, + // so agents have to depend on the same assemblies from the default ALC. + // Otherwise, the contracts will break due to mismatched type identities. 
+ ["System.CommandLine"] = null, + ["Microsoft.Extensions.AI.Abstractions"] = null, + }; if (OperatingSystem.IsWindows()) { diff --git a/shell/AIShell.Kernel/Utility/Utils.cs b/shell/AIShell.Kernel/Utility/Utils.cs index bc9428f0..d8cde85c 100644 --- a/shell/AIShell.Kernel/Utility/Utils.cs +++ b/shell/AIShell.Kernel/Utility/Utils.cs @@ -46,6 +46,7 @@ internal static class Utils internal static string ConfigHome; internal static string AppCacheDir; internal static string AppConfigFile; + internal static string AppMcpFile; internal static string AgentHome; internal static string AgentConfigHome; @@ -59,6 +60,7 @@ internal static void Setup(string appName) ConfigHome = Path.Combine(locationPath, $".{AppName.Replace(' ', '-')}"); AppCacheDir = Path.Combine(ConfigHome, ".cache"); AppConfigFile = Path.Combine(ConfigHome, "config.json"); + AppMcpFile = Path.Combine(ConfigHome, "mcp.json"); AgentHome = Path.Join(ConfigHome, "agents"); AgentConfigHome = Path.Join(ConfigHome, "agent-config"); diff --git a/shell/Markdown.VT/Markdown.VT.csproj b/shell/Markdown.VT/Markdown.VT.csproj index 12b0e835..fcf0d7f6 100644 --- a/shell/Markdown.VT/Markdown.VT.csproj +++ b/shell/Markdown.VT/Markdown.VT.csproj @@ -6,9 +6,9 @@ - + - + diff --git a/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj b/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj index d03afbaa..3e7464d1 100644 --- a/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj +++ b/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj @@ -17,7 +17,7 @@ - + diff --git a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj index b0c53299..21a674aa 100644 --- a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj +++ b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj @@ -21,11 +21,10 @@ - - - - - + + + + diff --git a/shell/agents/AIShell.OpenAI.Agent/Agent.cs 
b/shell/agents/AIShell.OpenAI.Agent/Agent.cs index c1a93c76..cbbcab1a 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Agent.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Agent.cs @@ -1,8 +1,7 @@ -using System.ClientModel; using System.Text; using System.Text.Json; using AIShell.Abstraction; -using OpenAI.Chat; +using Microsoft.Extensions.AI; namespace AIShell.OpenAI.Agent; @@ -116,42 +115,86 @@ public async Task ChatAsync(string input, IShell shell) return checkPass; } - IAsyncEnumerator response = await host + IAsyncEnumerator response = await host .RunWithSpinnerAsync( - () => _chatService.GetStreamingChatResponseAsync(input, token) + () => _chatService.GetStreamingChatResponseAsync(input, shell, token) ).ConfigureAwait(false); if (response is not null) { - StreamingChatCompletionUpdate update = null; + int? toolCalls = null; + bool isReasoning = false; + List updates = []; using var streamingRender = host.NewStreamRender(token); try { do { - update = response.Current; - if (update.ContentUpdate.Count > 0) + if (toolCalls is 0) { - streamingRender.Refresh(update.ContentUpdate[0].Text); + toolCalls = null; + } + + ChatResponseUpdate update = response.Current; + updates.Add(update); + + foreach (AIContent content in update.Contents) + { + if (content is TextReasoningContent reason) + { + if (isReasoning) + { + streamingRender.Refresh(reason.Text); + } + else + { + isReasoning = true; + streamingRender.Refresh($"\n{reason.Text}"); + } + + continue; + } + + string message = content switch + { + TextContent text => text.Text ?? string.Empty, + ErrorContent error => error.Message ?? string.Empty, + _ => null + }; + + if (message is null) + { + toolCalls = content switch + { + FunctionCallContent => (toolCalls + 1) ?? 
1, + FunctionResultContent => toolCalls - 1, + _ => toolCalls + }; + } + + if (isReasoning) + { + isReasoning = false; + message = $"\n\n\n{message}"; + } + + if (!string.IsNullOrEmpty(message)) + { + streamingRender.Refresh(message); + } } } - while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false)); + while (toolCalls is 0 + ? await host.RunWithSpinnerAsync(() => response.MoveNextAsync().AsTask()).ConfigureAwait(false) + : await response.MoveNextAsync().ConfigureAwait(false)); } catch (OperationCanceledException) { - update = null; + // Ignore cancellation exception. } - if (update is null) - { - _chatService.CalibrateChatHistory(usage: null, response: null); - } - else - { - string responseContent = streamingRender.AccumulatedContent; - _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent)); - } + _chatService.ChatHistory.AddMessages(updates); } return checkPass; diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs index 2ce1dd9f..790d705a 100644 --- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs +++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs @@ -1,5 +1,3 @@ -using Microsoft.ML.Tokenizers; - namespace AIShell.OpenAI.Agent; internal class ModelInfo @@ -11,7 +9,6 @@ internal class ModelInfo private const string Gpt34Encoding = "cl100k_base"; private static readonly Dictionary s_modelMap; - private static readonly Dictionary> s_encodingMap; // A rough estimate to cover all third-party models. // - most popular models today support 32K+ context length; @@ -37,21 +34,12 @@ static ModelInfo() // Azure naming of the 'gpt-3.5-turbo' models ["gpt-35-turbo"] = new(tokenLimit: 16_385), }; - - // The first call to 'GptEncoding.GetEncoding' is very slow, taking about 2 seconds on my machine. 
- // We don't immediately need the encodings at the startup, so by getting the values in tasks, - // we don't block the startup and the values will be ready when we really need them. - s_encodingMap = new(StringComparer.OrdinalIgnoreCase) - { - [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)), - [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding)) - }; } private ModelInfo(int tokenLimit, string encoding = null, bool reasoning = false) { TokenLimit = tokenLimit; - _encodingName = encoding ?? Gpt34Encoding; + EncodingName = encoding ?? Gpt34Encoding; // For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same. // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -60,22 +48,11 @@ private ModelInfo(int tokenLimit, string encoding = null, bool reasoning = false Reasoning = reasoning; } - private readonly string _encodingName; - private Tokenizer _gptEncoding; - + internal string EncodingName { get; } internal int TokenLimit { get; } internal int TokensPerMessage { get; } internal int TokensPerName { get; } internal bool Reasoning { get; } - internal Tokenizer Encoding - { - get { - _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task value) - ? value.Result - : TiktokenTokenizer.CreateForEncoding(_encodingName); - return _gptEncoding; - } - } /// /// Try resolving the specified model name. 
diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs index 026a5fde..b366b915 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Service.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs @@ -1,104 +1,47 @@ using System.ClientModel; using System.ClientModel.Primitives; using Azure.AI.OpenAI; -using Azure.Core; using Azure.Identity; -using Microsoft.ML.Tokenizers; +using Microsoft.Extensions.AI; using OpenAI; -using OpenAI.Chat; + +using OpenAIChatClient = OpenAI.Chat.ChatClient; +using ChatMessage = Microsoft.Extensions.AI.ChatMessage; +using AIShell.Abstraction; namespace AIShell.OpenAI.Agent; internal class ChatService { // TODO: Maybe expose this to our model registration? - // We can still use 1000 as the default value. - private const int MaxResponseToken = 2000; + private const int MaxResponseToken = 1000; private readonly string _historyRoot; private readonly List _chatHistory; - private readonly List _chatHistoryTokens; - private readonly ChatCompletionOptions _chatOptions; + private readonly ChatOptions _chatOptions; private GPT _gptToUse; private Settings _settings; - private ChatClient _client; - private int _totalInputToken; + private IChatClient _client; internal ChatService(string historyRoot, Settings settings) { _chatHistory = []; - _chatHistoryTokens = []; _historyRoot = historyRoot; - - _totalInputToken = 0; _settings = settings; - _chatOptions = new ChatCompletionOptions() + _chatOptions = new ChatOptions() { - MaxOutputTokenCount = MaxResponseToken, + MaxOutputTokens = MaxResponseToken, }; } internal List ChatHistory => _chatHistory; - internal void AddResponseToHistory(string response) - { - if (string.IsNullOrEmpty(response)) - { - return; - } - - _chatHistory.Add(ChatMessage.CreateAssistantMessage(response)); - } - internal void RefreshSettings(Settings settings) { _settings = settings; } - /// - /// It's almost impossible to relative-accurately calculate the token counts of all - /// 
messages, especially when tool calls are involved (tool call definitions and the - /// tool call payloads in AI response). - /// So, I decide to leverage the useage report from AI to track the token count of - /// the chat history. It's also an estimate, but I think more accurate than doing the - /// counting by ourselves. - /// - internal void CalibrateChatHistory(ChatTokenUsage usage, AssistantChatMessage response) - { - if (usage is null) - { - // Response was cancelled and we will remove the last query from history. - int index = _chatHistory.Count - 1; - _chatHistory.RemoveAt(index); - _chatHistoryTokens.RemoveAt(index); - - return; - } - - // Every reply is primed with <|start|>assistant<|message|>, so we subtract 3 from the 'InputTokenCount'. - int promptTokenCount = usage.InputTokenCount - 3; - // 'ReasoningTokenCount' should be 0 for non-o1 models. - int reasoningTokenCount = usage.OutputTokenDetails is null ? 0 : usage.OutputTokenDetails.ReasoningTokenCount; - int responseTokenCount = usage.OutputTokenCount - reasoningTokenCount; - - if (_totalInputToken is 0) - { - // It was the first user message, so instead of adjusting the user message token count, - // we set the token count for system message and tool calls. - _chatHistoryTokens[0] = promptTokenCount - _chatHistoryTokens[^1]; - } - else - { - // Adjust the token count of the user message, as our calculation is an estimate. 
- _chatHistoryTokens[^1] = promptTokenCount - _totalInputToken; - } - - _chatHistory.Add(response); - _chatHistoryTokens.Add(responseTokenCount); - _totalInputToken = promptTokenCount + responseTokenCount; - } - private void RefreshOpenAIClient() { if (ReferenceEquals(_gptToUse, _settings.Active)) @@ -110,7 +53,6 @@ private void RefreshOpenAIClient() GPT old = _gptToUse; _gptToUse = _settings.Active; _chatHistory.Clear(); - _chatHistoryTokens.Clear(); if (old is not null && old.Type == _gptToUse.Type @@ -124,6 +66,7 @@ private void RefreshOpenAIClient() return; } + OpenAIChatClient client; EndpointType type = _gptToUse.Type; // Reasoning models do not support the temperature setting. _chatOptions.Temperature = _gptToUse.ModelInfo.Reasoning ? null : 0; @@ -154,7 +97,7 @@ private void RefreshOpenAIClient() new ApiKeyCredential(azOpenAIApiKey), clientOptions); - _client = aiClient.GetChatClient(_gptToUse.Deployment); + client = aiClient.GetChatClient(_gptToUse.Deployment); } else { @@ -165,7 +108,7 @@ private void RefreshOpenAIClient() credential, clientOptions); - _client = aiClient.GetChatClient(_gptToUse.Deployment); + client = aiClient.GetChatClient(_gptToUse.Deployment); } } else @@ -179,91 +122,50 @@ private void RefreshOpenAIClient() string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions); - _client = aiClient.GetChatClient(_gptToUse.ModelName); - } - } - - /// - /// Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - /// - private int CountTokenForUserMessage(UserChatMessage message) - { - ModelInfo modelDetail = _gptToUse.ModelInfo; - Tokenizer encoding = modelDetail.Encoding; - - // Tokens per message plus 1 token for the role. 
- int tokenNumber = modelDetail.TokensPerMessage + 1; - foreach (ChatMessageContentPart part in message.Content) - { - tokenNumber += encoding.CountTokens(part.Text); + client = aiClient.GetChatClient(_gptToUse.ModelName); } - return tokenNumber; + _client = client.AsIChatClient() + .AsBuilder() + .UseFunctionInvocation(configure: c => c.IncludeDetailedErrors = true) + .Build(); } private void PrepareForChat(string input) { + const string Guidelines = """ + ## Tool Use Guidelines + You may have access to external tools. + Before making any tool call, you must first explain the reason for using the tool. Only issue the tool call after providing this explanation. + + ## Other Guidelines + """; + // Refresh the client in case the active model was changed. RefreshOpenAIClient(); if (_chatHistory.Count is 0) { - _chatHistory.Add(ChatMessage.CreateSystemMessage(_gptToUse.SystemPrompt)); - _chatHistoryTokens.Add(0); + string system = $"{Guidelines}\n{_gptToUse.SystemPrompt}"; + _chatHistory.Add(new(ChatRole.System, system)); } - var userMessage = new UserChatMessage(input); - int msgTokenCnt = CountTokenForUserMessage(userMessage); - _chatHistory.Add(userMessage); - _chatHistoryTokens.Add(msgTokenCnt); - - int inputLimit = _gptToUse.ModelInfo.TokenLimit; - // Every reply is primed with <|start|>assistant<|message|>, so adding 3 tokens. - int newTotal = _totalInputToken + msgTokenCnt + 3; - - // Shrink the chat history if we have less than 50 free tokens left (50-token buffer). - while (inputLimit - newTotal < 50) - { - // We remove a round of conversation for every trimming operation. 
- int userMsgCnt = 0; - List indices = []; - - for (int i = 0; i < _chatHistory.Count; i++) - { - if (_chatHistory[i] is UserChatMessage) - { - if (userMsgCnt is 1) - { - break; - } - - userMsgCnt++; - } - - if (userMsgCnt is 1) - { - indices.Add(i); - } - } - - foreach (int i in indices) - { - newTotal -= _chatHistoryTokens[i]; - } - - _chatHistory.RemoveRange(indices[0], indices.Count); - _chatHistoryTokens.RemoveRange(indices[0], indices.Count); - _totalInputToken = newTotal - msgTokenCnt; - } + _chatHistory.Add(new(ChatRole.User, input)); } - public async Task> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken) + public async Task> GetStreamingChatResponseAsync(string input, IShell shell, CancellationToken cancellationToken) { try { PrepareForChat(input); - IAsyncEnumerator enumerator = _client - .CompleteChatStreamingAsync(_chatHistory, _chatOptions, cancellationToken) + var tools = await shell.GetAIFunctions(); + if (tools is { Count: > 0 }) + { + _chatOptions.Tools = [.. tools]; + } + + IAsyncEnumerator enumerator = _client + .GetStreamingResponseAsync(_chatHistory, _chatOptions, cancellationToken) .GetAsyncEnumerator(cancellationToken); return await enumerator