diff --git a/shell/AIShell.Abstraction/NamedPipe.cs b/shell/AIShell.Abstraction/NamedPipe.cs index 6041803b..624b12ed 100644 --- a/shell/AIShell.Abstraction/NamedPipe.cs +++ b/shell/AIShell.Abstraction/NamedPipe.cs @@ -35,6 +35,32 @@ public enum MessageType : int PostCode = 4, } +/// +/// Context types that can be requested by AIShell from the connected PowerShell session. +/// +public enum ContextType : int +{ + /// + /// Ask for the current working directory of the shell. + /// + CurrentLocation = 0, + + /// + /// Ask for the command history of the shell session. + /// + CommandHistory = 1, + + /// + /// Ask for the content of the terminal window. + /// + TerminalContent = 2, + + /// + /// Ask for the environment variables of the shell session. + /// + EnvironmentVariables = 3, +} + /// /// Base class for all pipe messages. /// @@ -108,12 +134,24 @@ public AskConnectionMessage(string pipeName) /// public sealed class AskContextMessage : PipeMessage { + /// + /// Gets the type of context information requested. + /// + public ContextType ContextType { get; } + + /// + /// Gets the argument value associated with the current context query operation. + /// + public string[] Arguments { get; } + /// /// Creates an instance of . /// - public AskContextMessage() + public AskContextMessage(ContextType contextType, string[] arguments = null) : base(MessageType.AskContext) { + ContextType = contextType; + Arguments = arguments ?? null; } } @@ -125,21 +163,20 @@ public sealed class PostContextMessage : PipeMessage /// /// Represents a none instance to be used when the shell has no context information to return. /// - public static readonly PostContextMessage None = new([]); + public static readonly PostContextMessage None = new(contextInfo: null); /// - /// Gets the command history. + /// Gets the information of the requested context. /// - public List CommandHistory { get; } + public string ContextInfo { get; } /// /// Creates an instance of . /// - public PostContextMessage(List commandHistory) + public PostContextMessage(string contextInfo) : base(MessageType.PostContext) { - ArgumentNullException.ThrowIfNull(commandHistory); - CommandHistory = commandHistory; + ContextInfo = contextInfo; } } diff --git a/shell/AIShell.Integration/AIShell.psm1 b/shell/AIShell.Integration/AIShell.psm1 index c3b70a78..42910723 100644 --- a/shell/AIShell.Integration/AIShell.psm1 +++ b/shell/AIShell.Integration/AIShell.psm1 @@ -13,4 +13,4 @@ if ($null -eq $runspace) { } ## Create the channel singleton when loading the module. -$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, [Microsoft.PowerShell.PSConsoleReadLine]) +$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, $ExecutionContext, [Microsoft.PowerShell.PSConsoleReadLine]) diff --git a/shell/AIShell.Integration/Channel.cs b/shell/AIShell.Integration/Channel.cs index 2f05261f..7561167f 100644 --- a/shell/AIShell.Integration/Channel.cs +++ b/shell/AIShell.Integration/Channel.cs @@ -1,9 +1,13 @@ -using System.Diagnostics; +using System.Collections.ObjectModel; +using System.Diagnostics; using System.Reflection; using System.Text; using System.Management.Automation; +using System.Management.Automation.Host; using System.Management.Automation.Runspaces; using AIShell.Abstraction; +using Microsoft.PowerShell.Commands; +using System.Text.Json; namespace AIShell.Integration; @@ -15,13 +19,16 @@ public class Channel : IDisposable private readonly string _shellPipeName; private readonly Type _psrlType; private readonly Runspace _runspace; + private readonly EngineIntrinsics _intrinsics; private readonly MethodInfo _psrlInsert, _psrlRevertLine, _psrlAcceptLine; private readonly FieldInfo _psrlHandleResizing, _psrlReadLineReady; private readonly object _psrlSingleton; private readonly ManualResetEvent _connSetupWaitHandler; private readonly Predictor _predictor; private readonly ScriptBlock _onIdleAction; + private readonly List _commandHistory; + private PathInfo _currentLocation; private ShellClientPipe _clientPipe; private ShellServerPipe _serverPipe; private bool? _setupSuccess; @@ -29,14 +36,18 @@ public class Channel : IDisposable private Thread _serverThread; private CodePostData _pendingPostCodeData; - private Channel(Runspace runspace, Type psConsoleReadLineType) + private Channel(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType) { ArgumentNullException.ThrowIfNull(runspace); ArgumentNullException.ThrowIfNull(psConsoleReadLineType); _runspace = runspace; + _intrinsics = intrinsics; _psrlType = psConsoleReadLineType; _connSetupWaitHandler = new ManualResetEvent(false); + _currentLocation = _intrinsics.SessionState.Path.CurrentLocation; + _runspace.AvailabilityChanged += RunspaceAvailableAction; + _intrinsics.InvokeCommand.LocationChangedAction += LocationChangedAction; _shellPipeName = new StringBuilder(MaxNamedPipeNameSize) .Append("pwsh_aish.") @@ -57,13 +68,14 @@ private Channel(Runspace runspace, Type psConsoleReadLineType) _psrlReadLineReady = _psrlType.GetField("_readLineReady", fieldFlags); _psrlHandleResizing = _psrlType.GetField("_handlePotentialResizing", fieldFlags); + _commandHistory = []; _predictor = new Predictor(); _onIdleAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdleHandler()"); } - public static Channel CreateSingleton(Runspace runspace, Type psConsoleReadLineType) + public static Channel CreateSingleton(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType) { - return Singleton ??= new Channel(runspace, psConsoleReadLineType); + return Singleton ??= new Channel(runspace, intrinsics, psConsoleReadLineType); } public static Channel Singleton { get; private set; } @@ -127,6 +139,95 @@ private async void ThreadProc() await _serverPipe.StartProcessingAsync(ConnectionTimeout, CancellationToken.None); } + private void LocationChangedAction(object sender, LocationChangedEventArgs e) + { + _currentLocation = e.NewPath; + } + + private void RunspaceAvailableAction(object sender, RunspaceAvailabilityEventArgs e) + { + if (sender is null || e.RunspaceAvailability is not RunspaceAvailability.Available) + { + return; + } + + // It's safe to get states of the PowerShell Runspace now because it's available and this event + // is handled synchronously. + // We may want to invoke command or script here, and we have to unregister ourself before doing + // that, because the invocation would change the availability of the Runspace, which will cause + // the 'AvailabilityChanged' to be fired again and re-enter our handler. + // We register ourself back after we are done with the processing. + var pwshRunspace = (Runspace)sender; + pwshRunspace.AvailabilityChanged -= RunspaceAvailableAction; + + try + { + using var ps = PowerShell.Create(); + ps.Runspace = pwshRunspace; + + var results = ps + .AddCommand("Get-History") + .AddParameter("Count", 5) + .InvokeAndCleanup(); + + if (results.Count is 0 || + (_commandHistory.Count > 0 && _commandHistory[^1].Id == results[^1].Id)) + { + // No command history yet, or no change since the last update. + return; + } + + lock (_commandHistory) + { + _commandHistory.Clear(); + _commandHistory.AddRange(results); + } + } + finally + { + pwshRunspace.AvailabilityChanged += RunspaceAvailableAction; + } + } + + private string CaptureScreen() + { + if (!OperatingSystem.IsWindows()) + { + return null; + } + + try + { + PSHostRawUserInterface rawUI = _intrinsics.Host.UI.RawUI; + Coordinates start = new(0, 0), end = rawUI.CursorPosition; + end.X = rawUI.BufferSize.Width - 1; + + BufferCell[,] content = rawUI.GetBufferContents(new Rectangle(start, end)); + StringBuilder line = new(), buffer = new(); + + int rows = content.GetLength(0); + int columns = content.GetLength(1); + + for (int row = 0; row < rows; row++) + { + line.Clear(); + for (int column = 0; column < columns; column++) + { + line.Append(content[row, column].Character); + } + + line.TrimEnd(); + buffer.Append(line).Append('\n'); + } + + return buffer.Length is 0 ? string.Empty : buffer.ToString(); + } + catch + { + return null; + } + } + internal void PostQuery(PostQueryMessage message) { ThrowIfNotConnected(); @@ -138,6 +239,8 @@ public void Dispose() Reset(); _connSetupWaitHandler.Dispose(); _predictor.Unregister(); + _runspace.AvailabilityChanged -= RunspaceAvailableAction; + _intrinsics.InvokeCommand.LocationChangedAction -= LocationChangedAction; GC.SuppressFinalize(this); } @@ -257,8 +360,76 @@ private void OnPostCode(PostCodeMessage postCodeMessage) private PostContextMessage OnAskContext(AskContextMessage askContextMessage) { - // Not implemented yet. - return null; + const string RedactedValue = "******"; + + ContextType type = askContextMessage.ContextType; + string[] arguments = askContextMessage.Arguments; + + string contextInfo; + switch (type) + { + case ContextType.CurrentLocation: + contextInfo = JsonSerializer.Serialize( + new { Provider = _currentLocation.Provider.Name, _currentLocation.Path }); + break; + + case ContextType.CommandHistory: + lock (_commandHistory) + { + contextInfo = JsonSerializer.Serialize( + _commandHistory.Select(o => new { o.Id, o.CommandLine })); + } + break; + + case ContextType.TerminalContent: + contextInfo = CaptureScreen(); + break; + + case ContextType.EnvironmentVariables: + if (arguments is { Length: > 0 }) + { + var varsCopy = new Dictionary(); + foreach (string name in arguments) + { + if (!string.IsNullOrEmpty(name)) + { + varsCopy.Add(name, Environment.GetEnvironmentVariable(name) is string value + ? EnvVarMayBeSensitive(name) ? RedactedValue : value + : $"[env variable '{arguments}' is undefined]"); + } + } + + contextInfo = varsCopy.Count > 0 + ? JsonSerializer.Serialize(varsCopy) + : "The specified environment variable names are invalid"; + } + else + { + var vars = Environment.GetEnvironmentVariables(); + var varsCopy = new Dictionary(); + + foreach (string key in vars.Keys) + { + varsCopy.Add(key, EnvVarMayBeSensitive(key) ? RedactedValue : (string)vars[key]); + } + + contextInfo = JsonSerializer.Serialize(varsCopy); + } + break; + + default: + throw new InvalidDataException($"Unknown context type '{type}'"); + } + + return new PostContextMessage(contextInfo); + + static bool EnvVarMayBeSensitive(string key) + { + return key.Contains("key", StringComparison.OrdinalIgnoreCase) || + key.Contains("token", StringComparison.OrdinalIgnoreCase) || + key.Contains("pass", StringComparison.OrdinalIgnoreCase) || + key.Contains("secret", StringComparison.OrdinalIgnoreCase); + } } private void OnAskConnection(ShellClientPipe clientPipe, Exception exception) @@ -334,3 +505,39 @@ public void Dispose() } internal record CodePostData(string CodeToInsert, List PredictionCandidates); + +internal static class ExtensionMethods +{ + internal static Collection InvokeAndCleanup(this PowerShell ps) + { + var results = ps.Invoke(); + ps.Commands.Clear(); + + return results; + } + + internal static void InvokeAndCleanup(this PowerShell ps) + { + ps.Invoke(); + ps.Commands.Clear(); + } + + internal static void TrimEnd(this StringBuilder sb) + { + // end will point to the first non-trimmed character on the right. + int end = sb.Length - 1; + for (; end >= 0; end--) + { + if (!char.IsWhiteSpace(sb[end])) + { + break; + } + } + + int index = end + 1; + if (index < sb.Length) + { + sb.Remove(index, sb.Length - index); + } + } +} diff --git a/shell/AIShell.Integration/Commands/InvokeAishCommand.cs b/shell/AIShell.Integration/Commands/InvokeAishCommand.cs index 9c1d8f70..e4b23411 100644 --- a/shell/AIShell.Integration/Commands/InvokeAishCommand.cs +++ b/shell/AIShell.Integration/Commands/InvokeAishCommand.cs @@ -17,8 +17,8 @@ public class InvokeAIShellCommand : PSCmdlet /// /// Sets and gets the query to be sent to AIShell /// - [Parameter(Mandatory = true, ValueFromRemainingArguments = true, ParameterSetName = DefaultSet)] - [Parameter(Mandatory = true, ValueFromRemainingArguments = true, ParameterSetName = ClipboardSet)] + [Parameter(Position = 0, ParameterSetName = DefaultSet)] + [Parameter(Position = 0, ParameterSetName = ClipboardSet)] public string[] Query { get; set; } /// @@ -88,6 +88,26 @@ protected override void EndProcessing() message = "/exit"; break; default: + if (Query is not null) + { + message = string.Join(' ', Query); + } + else + { + Host.UI.Write("Query: "); + message = Host.UI.ReadLine(); + } + + if (string.IsNullOrEmpty(message)) + { + ThrowTerminatingError( + new ErrorRecord( + new ArgumentException("A query message is required."), + "QueryIsMissing", + ErrorCategory.InvalidArgument, + targetObject: null)); + } + Collection results = null; if (_contextObjects is not null) { @@ -107,7 +127,6 @@ protected override void EndProcessing() } context = results?.Count > 0 ? results[0] : null; - message = string.Join(' ', Query); break; } diff --git a/shell/AIShell.Kernel/Command/McpCommand.cs b/shell/AIShell.Kernel/Command/McpCommand.cs index 58f452cc..1deac2a3 100644 --- a/shell/AIShell.Kernel/Command/McpCommand.cs +++ b/shell/AIShell.Kernel/Command/McpCommand.cs @@ -28,8 +28,9 @@ private void ShowMCPData() { var shell = (Shell)Shell; var host = shell.Host; + var mcpManager = shell.McpManager; - if (shell.McpManager.McpServers.Count is 0) + if (mcpManager.McpServers.Count is 0 && mcpManager.BuiltInTools is null) { host.WriteErrorLine("No MCP server is available."); return; diff --git a/shell/AIShell.Kernel/Host.cs b/shell/AIShell.Kernel/Host.cs index 472fe223..7199f4cd 100644 --- a/shell/AIShell.Kernel/Host.cs +++ b/shell/AIShell.Kernel/Host.cs @@ -662,6 +662,20 @@ internal void RenderMcpServersAndTools(McpManager mcpManager) } } + if (mcpManager.BuiltInTools is { Count: > 0 }) + { + if (toolTable.Rows is { Count: > 0 }) + { + toolTable.AddEmptyRow(); + } + + toolTable.AddRow($"[olive underline]{McpManager.BuiltInServerName}[/]", "[green]\u2713 Ready[/]", string.Empty); + foreach (var item in mcpManager.BuiltInTools) + { + toolTable.AddRow(string.Empty, item.Key.EscapeMarkup(), item.Value.Description.EscapeMarkup()); + } + } + if (readyServers is not null) { foreach (var (name, status, info) in readyServers) diff --git a/shell/AIShell.Kernel/MCP/BuiltInTool.cs b/shell/AIShell.Kernel/MCP/BuiltInTool.cs new file mode 100644 index 00000000..3c2da224 --- /dev/null +++ b/shell/AIShell.Kernel/MCP/BuiltInTool.cs @@ -0,0 +1,427 @@ +using AIShell.Abstraction; +using Microsoft.Extensions.AI; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices.ObjectiveC; +using System.Text.Json; +using System.Xml.Linq; + +namespace AIShell.Kernel.Mcp; + +internal class BuiltInTool : AIFunction +{ + private enum ToolType : int + { + get_current_location = 0, + get_command_history = 1, + get_terminal_content = 2, + get_environment_variables = 3, + copy_text_to_clipboard = 4, + post_code_to_terminal = 5, + run_command_in_terminal = 6, + get_terminal_output = 7, + NumberOfBuiltInTools = 8 + }; + + private static readonly string[] s_toolDescription = + [ + // get_current_location + "Get the current location of the connected PowerShell session, including the provider name (e.g., `FileSystem`, `Certificate`) and the path (e.g., `C:\\`, `cert:\\`).", + + // get_command_history + "Get up to 5 of the most recent commands executed in the connected PowerShell session.", + + // get_terminal_content + "Get all output currently displayed in the terminal window of the connected PowerShell session.", + + //get_environment_variables + "Get environment variables and their values from the connected PowerShell session. Values of potentially sensitive variables are redacted.", + + // copy_text_to_clipboard + "Copy the provided text or code to the system clipboard, making it available for pasting elsewhere.", + + // post_code_to_terminal + "Insert code into the prompt of the connected PowerShell session without executing it. The user can review and choose to run it manually by pressing Enter.", + + // run_command_in_terminal + """ + This tool allows you to execute shell commands in a persistent PowerShell session, preserving environment variables, working directory, and other context across multiple commands. + + Command Execution: + - Supports chaining with `&&` or `;` (e.g., npm install && npm start). + - Supports multi-line commands + + Directory Management: + - Use absolute paths to avoid navigation issues. + + Program Execution: + - Supports running PowerShell commands and scripts. + - Supports Python, Node.js, and other executables. + - Install dependencies via pip, npm, etc. + + Background Processes: + - For long-running tasks (e.g., servers), set `isBackground=true`. + - Returns a terminal ID for checking status and runtime later. + + Important Notes: + - If the command may produce excessively large output, use head or tail to reduce the output. + - If a command may use a pager, you must add something to disable it. For example, you can use `git --no-pager`. Otherwise you should add something like ` | cat`. Examples: git, less, man, etc. + """, + + // get_terminal_output + "Get the output of a command previous started with `run_command_in_terminal`" + ]; + + private static readonly string[] s_toolSchema = + [ + // get_current_location + """ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // get_command_history + """ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // get_terminal_content + """ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // get_environment_variables + """ + { + "type": "object", + "properties": { + "names": { + "type": "array", + "items": { + "type": "string" + }, + "default": null, + "description": "Environment variable names to get values for. When no name is specified, returns all environment variables." + } + }, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // copy_text_to_clipboard + """ + { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Text or code to be copied to the system clipboard." + } + }, + "required": [ + "content" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // post_code_to_terminal + """ + { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Command or code snippet to be inserted to the terminal's prompt." + } + }, + "required": [ + "command" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // run_command_in_terminal + """ + { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Command to run in the connected PowerShell session." + }, + "explanation": { + "type": "string", + "description": "A one-sentence description of what the command does. This will be shown to the user before the command is run." + }, + "isBackground": { + "type": "boolean", + "description": "Whether the command starts a background process. If true, the command will run in the background and you will not see the output. If false, the tool call will block on the command finishing, and then you will get the output. Examples of backgrond processes: building in watch mode, starting a server. You can check the output of a backgrond process later on by using get_terminal_output." + } + }, + "required": [ + "command", + "explanation", + "isBackground" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """, + + // get_terminal_output + """ + { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The ID of the command to get the output from. This is the ID returned by the `run_command_in_terminal` tool." + } + }, + "required": [ + "id" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + """ + ]; + + private readonly string _fullName; + private readonly string _toolName; + private readonly ToolType _toolType; + private readonly string _description; + private readonly JsonElement _jsonSchema; + private readonly Shell _shell; + + private BuiltInTool(ToolType toolType, string description, JsonElement schema, Shell shell) + { + _toolType = toolType; + _shell = shell; + _toolName = toolType.ToString(); + _description = description; + _jsonSchema = schema; + + _fullName = $"{McpManager.BuiltInServerName}{McpManager.ServerToolSeparator}{_toolName}"; + } + + /// + /// The original tool name without the server name prefix. + /// + internal string OriginalName => _toolName; + + /// + /// The fully qualified name of the tool in the form of '.' + /// + public override string Name => _fullName; + + /// + public override string Description => _description; + + /// + public override JsonElement JsonSchema => _jsonSchema; + + /// + /// Overrides the base method with the call to . + /// + protected override async ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + var response = await CallAsync(arguments, cancellationToken).ConfigureAwait(false); + if (response is PostContextMessage postContextMsg) + { + return postContextMsg.ContextInfo; + } + + return response is null ? "Success: Function completed." : JsonSerializer.SerializeToElement(response); + } + + /// + /// Invokes the built-in tool. + /// + /// + /// An optional dictionary of arguments to pass to the tool. + /// Each key represents a parameter name, and its associated value represents the argument value. + /// + /// + /// The cancellation token to monitor for cancellation requests. + /// The default is . + /// + /// The user rejected the tool call. + internal async Task CallAsync( + IReadOnlyDictionary arguments = null, + CancellationToken cancellationToken = default) + { + AskContextMessage contextRequest = _toolType switch + { + ToolType.get_current_location => new(ContextType.CurrentLocation), + ToolType.get_command_history => new(ContextType.CommandHistory), + ToolType.get_terminal_content => new(ContextType.TerminalContent), + ToolType.get_environment_variables => new( + ContextType.EnvironmentVariables, + TryGetArgumentValue(arguments, "names", out string[] names) + ? names is { Length: > 0 } ? names : null + : null), + _ => null + }; + + bool succeeded = false; + PostContextMessage response = null; + + if (contextRequest is not null) + { + succeeded = true; + response = await _shell.Host.RunWithSpinnerAsync( + async () => await _shell.Channel.AskContext(contextRequest, cancellationToken), + status: $"Running '{_toolName}'", + spinnerKind: SpinnerKind.Processing); + } + else if (_toolType is ToolType.copy_text_to_clipboard) + { + TryGetArgumentValue(arguments, "content", out string content); + + if (string.IsNullOrEmpty(content)) + { + throw new ArgumentException("The 'content' argument is required for the 'copy_text_to_clipboard' tool."); + } + + succeeded = true; + Clipboard.SetText(content); + } + else if (_toolType is ToolType.post_code_to_terminal) + { + TryGetArgumentValue(arguments, "command", out string command); + + if (string.IsNullOrEmpty(command)) + { + throw new ArgumentException("The 'command' argument is required for the 'post_code_to_terminal' tool."); + } + + succeeded = true; + _shell.Channel.PostCode(new PostCodeMessage([command])); + } + + if (succeeded) + { + // Notify the user about this tool call. + _shell.Host.MarkupLine($"\n [green]\u2713[/] Ran '{_toolName}'"); + // Signal any active stream reander about the output + FancyStreamRender.ConsoleUpdated(); + + return response; + } + + throw new NotSupportedException($"Tool type '{_toolType}' is not yet supported."); + } + + private static bool TryGetArgumentValue(IReadOnlyDictionary arguments, string argName, out T value) + { + if (arguments is null || !arguments.TryGetValue(argName, out object argValue)) + { + value = default; + return false; + } + + if (argValue is T tValue) + { + value = tValue; + return true; + } + + if (argValue is JsonElement json) + { + Type tType = typeof(T); + JsonValueKind kind = json.ValueKind; + + if (tType == typeof(string)) + { + if (kind is JsonValueKind.String) + { + object stringValue = json.GetString(); + value = (T)stringValue; + return true; + } + + value = default; + return kind is JsonValueKind.Null; + } + + if (tType == typeof(string[])) + { + if (kind is JsonValueKind.Array) + { + object stringArray = json.EnumerateArray().Select(e => e.GetString()).ToArray(); + value = (T)stringArray; + return true; + } + + value = default; + return kind is JsonValueKind.Null; + } + + if (tType == typeof(bool) && kind is JsonValueKind.True or JsonValueKind.False) + { + value = (T)(object)json.GetBoolean(); + return true; + } + + if (tType == typeof(int) && kind is JsonValueKind.Number) + { + value = (T)(object)json.GetInt32(); + return true; + } + } + + value = default; + return false; + } + + /// + /// Gets the list of built-in tools available in AIShell. + /// + internal static Dictionary GetBuiltInTools(Shell shell) + { + ArgumentNullException.ThrowIfNull(shell); + + // We don't have the 'run_command' and 'get_terminal_output' tools yet. Will use 'ToolType.NumberOfBuiltInTools' when all tools are ready. + int toolCount = (int)ToolType.run_command_in_terminal; + Debug.Assert(s_toolDescription.Length == (int)ToolType.NumberOfBuiltInTools, "Number of tool descriptions doesn't match the number of tools."); + Debug.Assert(s_toolSchema.Length == (int)ToolType.NumberOfBuiltInTools, "Number of tool schemas doesn't match the number of tools."); + + if (shell.Channel is null || !shell.Channel.Connected) + { + return null; + } + + Dictionary tools = new(StringComparer.OrdinalIgnoreCase); + for (int i = 0; i < toolCount; i++) + { + ToolType toolType = (ToolType)i; + string description = s_toolDescription[i]; + JsonElement schema = JsonSerializer.Deserialize(s_toolSchema[i]); + BuiltInTool tool = new(toolType, description, schema, shell); + + tools.Add(tool.OriginalName, tool); + } + + return tools; + } +} diff --git a/shell/AIShell.Kernel/MCP/McpConfig.cs b/shell/AIShell.Kernel/MCP/McpConfig.cs index a8453eea..d192327a 100644 --- a/shell/AIShell.Kernel/MCP/McpConfig.cs +++ b/shell/AIShell.Kernel/MCP/McpConfig.cs @@ -188,6 +188,7 @@ internal enum McpType PropertyNameCaseInsensitive = true, ReadCommentHandling = JsonCommentHandling.Skip, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + NumberHandling = JsonNumberHandling.AllowReadingFromString, UseStringEnumConverter = true)] [JsonSerializable(typeof(McpConfig))] [JsonSerializable(typeof(McpServerConfig))] diff --git a/shell/AIShell.Kernel/MCP/McpManager.cs b/shell/AIShell.Kernel/MCP/McpManager.cs index b70575e8..1eef9d3a 100644 --- a/shell/AIShell.Kernel/MCP/McpManager.cs +++ b/shell/AIShell.Kernel/MCP/McpManager.cs @@ -1,6 +1,8 @@ -using Microsoft.Extensions.AI; +using AIShell.Abstraction; +using Microsoft.Extensions.AI; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; +using System.Threading; namespace AIShell.Kernel.Mcp; @@ -12,7 +14,10 @@ internal class McpManager private readonly TaskCompletionSource _parseMcpJsonTaskSource; private McpConfig _mcpConfig; + private Dictionary _builtInTools; + internal const string BuiltInServerName = "AIShell"; + internal const string ServerToolSeparator = "___"; internal Task ParseMcpJsonTask => _parseMcpJsonTaskSource.Task; internal Dictionary McpServers @@ -24,6 +29,15 @@ internal Dictionary McpServers } } + internal Dictionary BuiltInTools + { + get + { + _initTask.Wait(); + return _builtInTools; + } + } + internal McpManager(Shell shell) { _context = new(shell); @@ -45,6 +59,8 @@ private void Initialize() _parseMcpJsonTaskSource.SetException(e); } + _builtInTools = BuiltInTool.GetBuiltInTools(_context.Shell); + if (_mcpConfig is null) { return; @@ -65,6 +81,11 @@ internal async Task> ListAvailableTools() await _initTask; List tools = null; + if (_builtInTools.Count > 0) + { + (tools ??= []).AddRange(_builtInTools.Values); + } + foreach (var (name, server) in _mcpServers) { if (server.IsOperational) @@ -93,7 +114,7 @@ internal async Task CallToolAsync( string serverName = null, toolName = null; string functionName = functionCall.Name; - int dotIndex = functionName.IndexOf(McpTool.ServerToolSeparator); + int dotIndex = functionName.IndexOf(ServerToolSeparator); if (dotIndex > 0) { serverName = functionName[..dotIndex]; @@ -102,15 +123,7 @@ internal async Task CallToolAsync( await _initTask; - McpTool tool = null; - if (!string.IsNullOrEmpty(serverName) - && !string.IsNullOrEmpty(toolName) - && _mcpServers.TryGetValue(serverName, out McpServer server)) - { - await server.WaitForInitAsync(cancellationToken); - server.Tools.TryGetValue(toolName, out tool); - } - + AIFunction tool = await ResolveToolAsync(serverName, toolName, cancellationToken); if (tool is null) { return new FunctionResultContent( @@ -122,11 +135,18 @@ internal async Task CallToolAsync( try { - CallToolResponse response = await tool.CallAsync( - new AIFunctionArguments(arguments: functionCall.Arguments), - cancellationToken: cancellationToken); - - resultContent.Result = (object)response ?? "Success: Function completed."; + var args = new AIFunctionArguments(arguments: functionCall.Arguments); + if (tool is McpTool mcpTool) + { + CallToolResponse response = await mcpTool.CallAsync(args, cancellationToken: cancellationToken); + resultContent.Result = (object)response ?? "Success: Function completed."; + } + else + { + var builtInTool = (BuiltInTool)tool; + PipeMessage response = await builtInTool.CallAsync(args, cancellationToken); + resultContent.Result = (object)response ?? "Success: Function completed."; + } } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { @@ -142,6 +162,28 @@ internal async Task CallToolAsync( return resultContent; } + + private async Task ResolveToolAsync(string serverName, string toolName, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(serverName) || string.IsNullOrEmpty(toolName)) + { + return null; + } + + if (BuiltInServerName.Equals(serverName, StringComparison.OrdinalIgnoreCase)) + { + return _builtInTools.TryGetValue(toolName, out BuiltInTool builtInTool) ? builtInTool : null; + } + + McpTool mcpTool = null; + if (_mcpServers.TryGetValue(serverName, out McpServer server)) + { + await server.WaitForInitAsync(cancellationToken); + server.Tools.TryGetValue(toolName, out mcpTool); + } + + return mcpTool; + } } internal class McpServerInitContext diff --git a/shell/AIShell.Kernel/MCP/McpTool.cs b/shell/AIShell.Kernel/MCP/McpTool.cs index 10771b01..97142485 100644 --- a/shell/AIShell.Kernel/MCP/McpTool.cs +++ b/shell/AIShell.Kernel/MCP/McpTool.cs @@ -12,21 +12,23 @@ namespace AIShell.Kernel.Mcp; /// internal class McpTool : AIFunction { + internal static readonly string[] UserChoices = ["Continue", "Cancel"]; + private readonly string _fullName; private readonly string _serverName; private readonly Host _host; private readonly McpClientTool _clientTool; - private readonly string[] _userChoices; - - internal const string ServerToolSeparator = "___"; internal McpTool(string serverName, McpClientTool clientTool, Host host) { + ArgumentException.ThrowIfNullOrEmpty(serverName); + ArgumentNullException.ThrowIfNull(clientTool); + ArgumentNullException.ThrowIfNull(host); + _host = host; _clientTool = clientTool; - _fullName = $"{serverName}{ServerToolSeparator}{clientTool.Name}"; + _fullName = $"{serverName}{McpManager.ServerToolSeparator}{clientTool.Name}"; _serverName = serverName; - _userChoices = ["Continue", "Cancel"]; } /// @@ -117,7 +119,7 @@ internal async ValueTask CallAsync( const string title = "\n\u26A0 MCP servers or malicious converstaion content may attempt to misuse 'AIShell' through the installed tools. Please carefully review any requested actions to decide if you want to proceed."; string choice = await _host.PromptForSelectionAsync( title: title, - choices: _userChoices, + choices: UserChoices, cancellationToken: cancellationToken); if (choice is "Cancel")