Skip to content

Add built-in tools to AIShell #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions shell/AIShell.Abstraction/NamedPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ public enum MessageType : int
PostCode = 4,
}

/// <summary>
/// Context types that can be requested by AIShell from the connected PowerShell session.
/// </summary>
public enum ContextType : int
{
/// <summary>
/// Ask for the current working directory of the shell.
/// </summary>
CurrentLocation = 0,

/// <summary>
/// Ask for the command history of the shell session.
/// </summary>
CommandHistory = 1,

/// <summary>
/// Ask for the content of the terminal window.
/// </summary>
TerminalContent = 2,

/// <summary>
/// Ask for the environment variables of the shell session.
/// </summary>
EnvironmentVariables = 3,
}

/// <summary>
/// Base class for all pipe messages.
/// </summary>
Expand Down Expand Up @@ -108,12 +134,24 @@ public AskConnectionMessage(string pipeName)
/// </summary>
public sealed class AskContextMessage : PipeMessage
{
/// <summary>
/// Gets the type of context information requested.
/// </summary>
public ContextType ContextType { get; }

/// <summary>
/// Gets the argument value associated with the current context query operation.
/// </summary>
public string[] Arguments { get; }

/// <summary>
/// Creates an instance of <see cref="AskContextMessage"/>.
/// </summary>
public AskContextMessage()
public AskContextMessage(ContextType contextType, string[] arguments = null)
: base(MessageType.AskContext)
{
ContextType = contextType;
Arguments = arguments ?? null;
}
}

Expand All @@ -125,21 +163,20 @@ public sealed class PostContextMessage : PipeMessage
/// <summary>
/// Represents a none instance to be used when the shell has no context information to return.
/// </summary>
public static readonly PostContextMessage None = new([]);
public static readonly PostContextMessage None = new(contextInfo: null);

/// <summary>
/// Gets the command history.
/// Gets the information of the requested context.
/// </summary>
public List<string> CommandHistory { get; }
public string ContextInfo { get; }

/// <summary>
/// Creates an instance of <see cref="PostContextMessage"/>.
/// </summary>
public PostContextMessage(List<string> commandHistory)
public PostContextMessage(string contextInfo)
: base(MessageType.PostContext)
{
ArgumentNullException.ThrowIfNull(commandHistory);
CommandHistory = commandHistory;
ContextInfo = contextInfo;
}
}

Expand Down
2 changes: 1 addition & 1 deletion shell/AIShell.Integration/AIShell.psm1
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ if ($null -eq $runspace) {
}

## Create the channel singleton when loading the module.
$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, [Microsoft.PowerShell.PSConsoleReadLine])
$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, $ExecutionContext, [Microsoft.PowerShell.PSConsoleReadLine])
219 changes: 213 additions & 6 deletions shell/AIShell.Integration/Channel.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
using System.Diagnostics;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Management.Automation;
using System.Management.Automation.Host;
using System.Management.Automation.Runspaces;
using AIShell.Abstraction;
using Microsoft.PowerShell.Commands;
using System.Text.Json;

namespace AIShell.Integration;

Expand All @@ -15,28 +19,35 @@ public class Channel : IDisposable
private readonly string _shellPipeName;
private readonly Type _psrlType;
private readonly Runspace _runspace;
private readonly EngineIntrinsics _intrinsics;
private readonly MethodInfo _psrlInsert, _psrlRevertLine, _psrlAcceptLine;
private readonly FieldInfo _psrlHandleResizing, _psrlReadLineReady;
private readonly object _psrlSingleton;
private readonly ManualResetEvent _connSetupWaitHandler;
private readonly Predictor _predictor;
private readonly ScriptBlock _onIdleAction;
private readonly List<HistoryInfo> _commandHistory;

private PathInfo _currentLocation;
private ShellClientPipe _clientPipe;
private ShellServerPipe _serverPipe;
private bool? _setupSuccess;
private Exception _exception;
private Thread _serverThread;
private CodePostData _pendingPostCodeData;

private Channel(Runspace runspace, Type psConsoleReadLineType)
private Channel(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType)
{
ArgumentNullException.ThrowIfNull(runspace);
ArgumentNullException.ThrowIfNull(psConsoleReadLineType);

_runspace = runspace;
_intrinsics = intrinsics;
_psrlType = psConsoleReadLineType;
_connSetupWaitHandler = new ManualResetEvent(false);
_currentLocation = _intrinsics.SessionState.Path.CurrentLocation;
_runspace.AvailabilityChanged += RunspaceAvailableAction;
_intrinsics.InvokeCommand.LocationChangedAction += LocationChangedAction;

_shellPipeName = new StringBuilder(MaxNamedPipeNameSize)
.Append("pwsh_aish.")
Expand All @@ -57,13 +68,14 @@ private Channel(Runspace runspace, Type psConsoleReadLineType)
_psrlReadLineReady = _psrlType.GetField("_readLineReady", fieldFlags);
_psrlHandleResizing = _psrlType.GetField("_handlePotentialResizing", fieldFlags);

_commandHistory = [];
_predictor = new Predictor();
_onIdleAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdleHandler()");
}

public static Channel CreateSingleton(Runspace runspace, Type psConsoleReadLineType)
public static Channel CreateSingleton(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType)
{
return Singleton ??= new Channel(runspace, psConsoleReadLineType);
return Singleton ??= new Channel(runspace, intrinsics, psConsoleReadLineType);
}

public static Channel Singleton { get; private set; }
Expand Down Expand Up @@ -127,6 +139,95 @@ private async void ThreadProc()
await _serverPipe.StartProcessingAsync(ConnectionTimeout, CancellationToken.None);
}

private void LocationChangedAction(object sender, LocationChangedEventArgs e)
{
_currentLocation = e.NewPath;
}

private void RunspaceAvailableAction(object sender, RunspaceAvailabilityEventArgs e)
{
if (sender is null || e.RunspaceAvailability is not RunspaceAvailability.Available)
{
return;
}

// It's safe to get states of the PowerShell Runspace now because it's available and this event
// is handled synchronously.
// We may want to invoke command or script here, and we have to unregister ourself before doing
// that, because the invocation would change the availability of the Runspace, which will cause
// the 'AvailabilityChanged' to be fired again and re-enter our handler.
// We register ourself back after we are done with the processing.
var pwshRunspace = (Runspace)sender;
pwshRunspace.AvailabilityChanged -= RunspaceAvailableAction;

try
{
using var ps = PowerShell.Create();
ps.Runspace = pwshRunspace;

var results = ps
.AddCommand("Get-History")
.AddParameter("Count", 5)
.InvokeAndCleanup<HistoryInfo>();

if (results.Count is 0 ||
(_commandHistory.Count > 0 && _commandHistory[^1].Id == results[^1].Id))
{
// No command history yet, or no change since the last update.
return;
}

lock (_commandHistory)
{
_commandHistory.Clear();
_commandHistory.AddRange(results);
}
}
finally
{
pwshRunspace.AvailabilityChanged += RunspaceAvailableAction;
}
}

private string CaptureScreen()
{
if (!OperatingSystem.IsWindows())
{
return null;
}

try
{
PSHostRawUserInterface rawUI = _intrinsics.Host.UI.RawUI;
Coordinates start = new(0, 0), end = rawUI.CursorPosition;
end.X = rawUI.BufferSize.Width - 1;

BufferCell[,] content = rawUI.GetBufferContents(new Rectangle(start, end));
StringBuilder line = new(), buffer = new();

int rows = content.GetLength(0);
int columns = content.GetLength(1);

for (int row = 0; row < rows; row++)
{
line.Clear();
for (int column = 0; column < columns; column++)
{
line.Append(content[row, column].Character);
}

line.TrimEnd();
buffer.Append(line).Append('\n');
}

return buffer.Length is 0 ? string.Empty : buffer.ToString();
}
catch
{
return null;
}
}

internal void PostQuery(PostQueryMessage message)
{
ThrowIfNotConnected();
Expand All @@ -138,6 +239,8 @@ public void Dispose()
Reset();
_connSetupWaitHandler.Dispose();
_predictor.Unregister();
_runspace.AvailabilityChanged -= RunspaceAvailableAction;
_intrinsics.InvokeCommand.LocationChangedAction -= LocationChangedAction;
GC.SuppressFinalize(this);
}

Expand Down Expand Up @@ -257,8 +360,76 @@ private void OnPostCode(PostCodeMessage postCodeMessage)

private PostContextMessage OnAskContext(AskContextMessage askContextMessage)
{
// Not implemented yet.
return null;
const string RedactedValue = "***<sensitive data redacted>***";

ContextType type = askContextMessage.ContextType;
string[] arguments = askContextMessage.Arguments;

string contextInfo;
switch (type)
{
case ContextType.CurrentLocation:
contextInfo = JsonSerializer.Serialize(
new { Provider = _currentLocation.Provider.Name, _currentLocation.Path });
break;

case ContextType.CommandHistory:
lock (_commandHistory)
{
contextInfo = JsonSerializer.Serialize(
_commandHistory.Select(o => new { o.Id, o.CommandLine }));
}
break;

case ContextType.TerminalContent:
contextInfo = CaptureScreen();
break;

case ContextType.EnvironmentVariables:
if (arguments is { Length: > 0 })
{
var varsCopy = new Dictionary<string, string>();
foreach (string name in arguments)
{
if (!string.IsNullOrEmpty(name))
{
varsCopy.Add(name, Environment.GetEnvironmentVariable(name) is string value
? EnvVarMayBeSensitive(name) ? RedactedValue : value
: $"[env variable '{arguments}' is undefined]");
}
}

contextInfo = varsCopy.Count > 0
? JsonSerializer.Serialize(varsCopy)
: "The specified environment variable names are invalid";
}
else
{
var vars = Environment.GetEnvironmentVariables();
var varsCopy = new Dictionary<string, string>();

foreach (string key in vars.Keys)
{
varsCopy.Add(key, EnvVarMayBeSensitive(key) ? RedactedValue : (string)vars[key]);
}

contextInfo = JsonSerializer.Serialize(varsCopy);
}
break;

default:
throw new InvalidDataException($"Unknown context type '{type}'");
}

return new PostContextMessage(contextInfo);

static bool EnvVarMayBeSensitive(string key)
{
return key.Contains("key", StringComparison.OrdinalIgnoreCase) ||
key.Contains("token", StringComparison.OrdinalIgnoreCase) ||
key.Contains("pass", StringComparison.OrdinalIgnoreCase) ||
key.Contains("secret", StringComparison.OrdinalIgnoreCase);
}
}

private void OnAskConnection(ShellClientPipe clientPipe, Exception exception)
Expand Down Expand Up @@ -334,3 +505,39 @@ public void Dispose()
}

internal record CodePostData(string CodeToInsert, List<PredictionCandidate> PredictionCandidates);

internal static class ExtensionMethods
{
internal static Collection<T> InvokeAndCleanup<T>(this PowerShell ps)
{
var results = ps.Invoke<T>();
ps.Commands.Clear();

return results;
}

internal static void InvokeAndCleanup(this PowerShell ps)
{
ps.Invoke();
ps.Commands.Clear();
}

internal static void TrimEnd(this StringBuilder sb)
{
// end will point to the first non-trimmed character on the right.
int end = sb.Length - 1;
for (; end >= 0; end--)
{
if (!char.IsWhiteSpace(sb[end]))
{
break;
}
}

int index = end + 1;
if (index < sb.Length)
{
sb.Remove(index, sb.Length - index);
}
}
}
Loading