Skip to content

Commit 1acd0e0

Browse files
authored
Add the built-in tool run_command_in_terminal for AI to execute commands in the connected PowerShell session (#398)
Allow AIShell to run command in the connected PowerShell session and collect all output and error. 1. Add `Invoke-AICommand` cmdlet (alias: `airun`) to `AIShell` module. Commands sent from the sidecar AIShell will be executed through this command in the form of `airun { <command> }`. This command is designed to collect all output and error messages as they are displayed in the terminal, while preserving the streaming behavior as expected. 2. Add `RunCommand` and `PostResult` messages to the protocol. 3. Update the `Channel` class in `AIShell` module to support the `OnRunCommand` action. We already support posting command to the PowerShell's prompt, but it turns out not easy to make the command be accepted. On Windows, we have to call `AcceptLine` within an `OnIdle` event handler and it also requires changes to `PSReadLine`. - `AcceptLine` only set a flag in `PSReadLine` to indicate the line was accepted. The flag is checked in `InputLoop`, however, when `PSReadLine` is waiting for input, it's blocked in the `ReadKey` call within `InputLoop`, so even if the flag is set, `InputLoop` won't be able to check the flag until after `ReadKey` call is returned. - I need to change PSReadLine a bit: after it finishes handling the `OnIdle` event, it checks if the `_lineAccepted` flag is set. If it's set, it means `AcceptLine` got called within the `OnIdle` handler, and it throws a `LineAcceptedException` to break out from `ReadKey`. I catch this exception in `InputLoop` to continue with the flag check. - However, a problem with this change is: the "readkey thread" is still blocked on `Console.ReadKey` when the command is returned to PowerShell to execute. On Windows, this could cause minor issues if the command also calls `Console.ReadKey` -- 2 threads calling `Console.ReadKey` in parallel, so it's uncertain which will get the next keystroke input. On macOS and Linux, the problem is way much bigger -- any subsequent writing to the terminal may be blocked, because on Unix platforms, reading cursor position will be blocked if another thread is calling `Console.ReadKey`. - So, this approach can only work on Windows. On macOS, we will have to depend on `iTerm2`, which has a Python API server that allows to send keystrokes to a tab using the Python API, so we could possibly use that for macOS. But Windows Terminal doesn't support that, and thus we will have to use the above approach to accept the command on Windows. - On macOS, if the Python API approach works fine, then we could even consider using it for the `PostCode` action. 4. Add `run_command_in_terminal` and `get_command_output` tools and expose them to agents.
1 parent cd50156 commit 1acd0e0

File tree

10 files changed

+686
-41
lines changed

10 files changed

+686
-41
lines changed

shell/AIShell.Abstraction/NamedPipe.cs

Lines changed: 237 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ public enum MessageType : int
3333
/// A message from AIShell to command-line shell to send code block.
3434
/// </summary>
3535
PostCode = 4,
36+
37+
/// <summary>
38+
/// A message from AIShell to command-line shell to run a command.
39+
/// </summary>
40+
RunCommand = 5,
41+
42+
/// <summary>
43+
/// A message from AIShell to command-line shell to ask for the result of a previous command run.
44+
/// </summary>
45+
AskCommandOutput = 6,
46+
47+
/// <summary>
48+
/// A message from command-line shell to AIShell to post the result of a command.
49+
/// </summary>
50+
PostResult = 7,
3651
}
3752

3853
/// <summary>
@@ -201,6 +216,95 @@ public PostCodeMessage(List<string> codeBlocks)
201216
}
202217
}
203218

219+
/// <summary>
220+
/// Message for <see cref="MessageType.RunCommand"/>.
221+
/// </summary>
222+
public sealed class RunCommandMessage : PipeMessage
223+
{
224+
/// <summary>
225+
/// Gets the command to run.
226+
/// </summary>
227+
public string Command { get; }
228+
229+
/// <summary>
230+
/// Gets whether the command should be run in blocking mode.
231+
/// </summary>
232+
public bool Blocking { get; }
233+
234+
/// <summary>
235+
/// Creates an instance of <see cref="RunCommandMessage"/>.
236+
/// </summary>
237+
public RunCommandMessage(string command, bool blocking)
238+
: base(MessageType.RunCommand)
239+
{
240+
ArgumentException.ThrowIfNullOrEmpty(command);
241+
242+
Command = command;
243+
Blocking = blocking;
244+
}
245+
}
246+
247+
/// <summary>
248+
/// Message for <see cref="MessageType.AskCommandOutput"/>.
249+
/// </summary>
250+
public sealed class AskCommandOutputMessage : PipeMessage
251+
{
252+
/// <summary>
253+
/// Gets the id of the command to retrieve the output for.
254+
/// </summary>
255+
public string CommandId { get; }
256+
257+
/// <summary>
258+
/// Creates an instance of <see cref="AskCommandOutputMessage"/>.
259+
/// </summary>
260+
public AskCommandOutputMessage(string commandId)
261+
: base(MessageType.AskCommandOutput)
262+
{
263+
ArgumentException.ThrowIfNullOrEmpty(commandId);
264+
CommandId = commandId;
265+
}
266+
}
267+
268+
/// <summary>
269+
/// Message for <see cref="MessageType.PostResult"/>.
270+
/// </summary>
271+
public sealed class PostResultMessage : PipeMessage
272+
{
273+
/// <summary>
274+
/// Gets the result of the command for a blocking 'run_command' too call.
275+
/// Or, for a non-blocking call, gets the id for retrieving the result later.
276+
/// </summary>
277+
public string Output { get; }
278+
279+
/// <summary>
280+
/// Gets whether the command execution had any error.
281+
/// i.e. a native command returned a non-zero exit code, or a powershell command threw any errors.
282+
/// </summary>
283+
public bool HadError { get; }
284+
285+
/// <summary>
286+
/// Gets a value indicating whether the operation was canceled by the user.
287+
/// </summary>
288+
public bool UserCancelled { get; }
289+
290+
/// <summary>
291+
/// Gets the internal exception message that is thrown when trying to run the command.
292+
/// </summary>
293+
public string Exception { get; }
294+
295+
/// <summary>
296+
/// Creates an instance of <see cref="PostResultMessage"/>.
297+
/// </summary>
298+
public PostResultMessage(string output, bool hadError, bool userCancelled, string exception)
299+
: base(MessageType.PostResult)
300+
{
301+
Output = output;
302+
HadError = hadError;
303+
UserCancelled = userCancelled;
304+
Exception = exception;
305+
}
306+
}
307+
204308
/// <summary>
205309
/// The base type for common pipe operations.
206310
/// </summary>
@@ -301,7 +405,7 @@ protected async Task<PipeMessage> GetMessageAsync(CancellationToken cancellation
301405
return null;
302406
}
303407

304-
if (type > (int)MessageType.PostCode)
408+
if (type > (int)MessageType.PostResult)
305409
{
306410
_pipeStream.Close();
307411
throw new IOException($"Unknown message type received: {type}. Connection was dropped.");
@@ -344,9 +448,12 @@ private static PipeMessage DeserializePayload(int type, ReadOnlySpan<byte> bytes
344448
{
345449
(int)MessageType.PostQuery => JsonSerializer.Deserialize<PostQueryMessage>(bytes),
346450
(int)MessageType.AskConnection => JsonSerializer.Deserialize<AskConnectionMessage>(bytes),
347-
(int)MessageType.PostContext => JsonSerializer.Deserialize<PostContextMessage>(bytes),
348451
(int)MessageType.AskContext => JsonSerializer.Deserialize<AskContextMessage>(bytes),
452+
(int)MessageType.PostContext => JsonSerializer.Deserialize<PostContextMessage>(bytes),
349453
(int)MessageType.PostCode => JsonSerializer.Deserialize<PostCodeMessage>(bytes),
454+
(int)MessageType.RunCommand => JsonSerializer.Deserialize<RunCommandMessage>(bytes),
455+
(int)MessageType.AskCommandOutput => JsonSerializer.Deserialize<AskCommandOutputMessage>(bytes),
456+
(int)MessageType.PostResult => JsonSerializer.Deserialize<PostResultMessage>(bytes),
350457
_ => throw new NotSupportedException("Unreachable code"),
351458
};
352459
}
@@ -465,6 +572,16 @@ public async Task StartProcessingAsync(int timeout, CancellationToken cancellati
465572
InvokeOnPostCode((PostCodeMessage)message);
466573
break;
467574

575+
case MessageType.RunCommand:
576+
var result = InvokeOnRunCommand((RunCommandMessage)message);
577+
SendMessage(result);
578+
break;
579+
580+
case MessageType.AskCommandOutput:
581+
var output = InvokeOnAskCommandOutput((AskCommandOutputMessage)message);
582+
SendMessage(output);
583+
break;
584+
468585
default:
469586
// Log: unexpected messages ignored.
470587
break;
@@ -537,6 +654,66 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message)
537654
return null;
538655
}
539656

657+
/// <summary>
658+
/// Helper to invoke the <see cref="OnRunCommand"/> event.
659+
/// </summary>
660+
private PostResultMessage InvokeOnRunCommand(RunCommandMessage message)
661+
{
662+
if (OnRunCommand is null)
663+
{
664+
// Log: event handler not set.
665+
return new PostResultMessage(
666+
output: "Command execution is not supported.",
667+
hadError: true,
668+
userCancelled: false,
669+
exception: null);
670+
}
671+
672+
try
673+
{
674+
return OnRunCommand(message);
675+
}
676+
catch (Exception e)
677+
{
678+
// Log: exception when invoking 'OnRunCommand'
679+
return new PostResultMessage(
680+
output: "Failed to execute the command due to an internal error.",
681+
hadError: true,
682+
userCancelled: false,
683+
exception: e.Message);
684+
}
685+
}
686+
687+
/// <summary>
688+
/// Helper to invoke the <see cref="OnAskCommandOutput"/> event.
689+
/// </summary>
690+
private PostResultMessage InvokeOnAskCommandOutput(AskCommandOutputMessage message)
691+
{
692+
if (OnAskCommandOutput is null)
693+
{
694+
// Log: event handler not set.
695+
return new PostResultMessage(
696+
output: "Retrieving command output is not supported.",
697+
hadError: true,
698+
userCancelled: false,
699+
exception: null);
700+
}
701+
702+
try
703+
{
704+
return OnAskCommandOutput(message);
705+
}
706+
catch (Exception e)
707+
{
708+
// Log: exception when invoking 'OnAskCommandOutput'
709+
return new PostResultMessage(
710+
output: "Failed to retrieve the command output due to an internal error.",
711+
hadError: true,
712+
userCancelled: false,
713+
exception: e.Message);
714+
}
715+
}
716+
540717
/// <summary>
541718
/// Event for handling the <see cref="MessageType.PostCode"/> message.
542719
/// </summary>
@@ -551,6 +728,16 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message)
551728
/// Event for handling the <see cref="MessageType.AskContext"/> message.
552729
/// </summary>
553730
public event Func<AskContextMessage, PostContextMessage> OnAskContext;
731+
732+
/// <summary>
733+
/// Event for handling the <see cref="MessageType.RunCommand"/> message.
734+
/// </summary>
735+
public event Func<RunCommandMessage, PostResultMessage> OnRunCommand;
736+
737+
/// <summary>
738+
/// Event for handling the <see cref="MessageType.AskCommandOutput"/> message.
739+
/// </summary>
740+
public event Func<AskCommandOutputMessage, PostResultMessage> OnAskCommandOutput;
554741
}
555742

556743
/// <summary>
@@ -771,4 +958,52 @@ public async Task<PostContextMessage> AskContext(AskContextMessage message, Canc
771958

772959
return postContext;
773960
}
961+
962+
/// <summary>
963+
/// Run a command in the connected PowerShell session.
964+
/// </summary>
965+
/// <param name="message">The <see cref="MessageType.RunCommand"/> message.</param>
966+
/// <param name="cancellationToken">A cancellation token.</param>
967+
/// <returns>A <see cref="MessageType.PostResult"/> message as the response.</returns>
968+
/// <exception cref="IOException">Throws when the pipe is closed by the other side.</exception>
969+
public async Task<PostResultMessage> RunCommand(RunCommandMessage message, CancellationToken cancellationToken)
970+
{
971+
// Send the request message to the shell.
972+
SendMessage(message);
973+
974+
// Receiving response from the shell.
975+
var response = await GetMessageAsync(cancellationToken);
976+
if (response is not PostResultMessage postResult)
977+
{
978+
// Log: unexpected message. drop connection.
979+
_client.Close();
980+
throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message.");
981+
}
982+
983+
return postResult;
984+
}
985+
986+
/// <summary>
987+
/// Ask for the output of a previously run command in the connected PowerShell session.
988+
/// </summary>
989+
/// <param name="message">The <see cref="MessageType.AskCommandOutput"/> message.</param>
990+
/// <param name="cancellationToken">A cancellation token.</param>
991+
/// <returns>A <see cref="MessageType.PostResult"/> message as the response.</returns>
992+
/// <exception cref="IOException">Throws when the pipe is closed by the other side.</exception>
993+
public async Task<PostResultMessage> AskCommandOutput(AskCommandOutputMessage message, CancellationToken cancellationToken)
994+
{
995+
// Send the request message to the shell.
996+
SendMessage(message);
997+
998+
// Receiving response from the shell.
999+
var response = await GetMessageAsync(cancellationToken);
1000+
if (response is not PostResultMessage postResult)
1001+
{
1002+
// Log: unexpected message. drop connection.
1003+
_client.Close();
1004+
throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message.");
1005+
}
1006+
1007+
return postResult;
1008+
}
7741009
}

shell/AIShell.Integration/AIShell.psd1

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
PowerShellVersion = '7.4.6'
1111
PowerShellHostName = 'ConsoleHost'
1212
FunctionsToExport = @()
13-
CmdletsToExport = @('Start-AIShell','Invoke-AIShell','Resolve-Error')
13+
CmdletsToExport = @('Start-AIShell','Invoke-AIShell', 'Invoke-AICommand', 'Resolve-Error')
1414
VariablesToExport = '*'
15-
AliasesToExport = @('aish', 'askai', 'fixit')
15+
AliasesToExport = @('aish', 'askai', 'fixit', 'airun')
1616
HelpInfoURI = 'https://aka.ms/aishell-help'
1717
PrivateData = @{ PSData = @{ Prerelease = 'preview5'; ProjectUri = 'https://github.com/PowerShell/AIShell' } }
1818
}

0 commit comments

Comments
 (0)