Skip to content

Add some basic git command argument prediction #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions src/CompletionPredictor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics.CodeAnalysis;
using System.Management.Automation.Language;
using System.Management.Automation.Runspaces;
using System.Management.Automation.Subsystem.Prediction;
Expand All @@ -10,6 +11,8 @@ public partial class CompletionPredictor : ICommandPredictor, IDisposable
{
private readonly Guid _guid;
private readonly Runspace _runspace;
private readonly GitHandler _gitHandler;
private string? _cwd;
private int _lock = 1;

private static HashSet<string> s_cmdList = new(StringComparer.OrdinalIgnoreCase)
Expand All @@ -21,11 +24,13 @@ public partial class CompletionPredictor : ICommandPredictor, IDisposable
"where",
"Where-Object",
"cd",
"git",
};

internal CompletionPredictor(string guid)
{
_guid = new Guid(guid);
_gitHandler = new GitHandler();
_runspace = RunspaceFactory.CreateRunspace(InitialSessionState.CreateDefault());
_runspace.Name = nameof(CompletionPredictor);
_runspace.Open();
Expand All @@ -47,25 +52,55 @@ public SuggestionPackage GetSuggestion(PredictionClient client, PredictionContex
{
// When it ends at a white space, it would likely trigger argument completion which in most cases would be file-operation
// intensive. That's not only slow but also undesirable in most cases, so we skip it.
// But, there are exceptions for 'ForEach-Object' and 'Where-Object', where completion on member names is quite useful.
Ast lastAst = relatedAsts[^1];
var cmdName = (lastAst.Parent as CommandAst)?.CommandElements[0] as StringConstantExpressionAst;
if (cmdName is null || !s_cmdList.Contains(cmdName.Value) || !object.ReferenceEquals(lastAst, cmdName))
// But, there are exceptions for some commands, where completion on member names is quite useful.
if (!IsCommandAstWithLiteralName(context, out var cmdAst, out var nameAst)
|| !s_cmdList.TryGetValue(nameAst.Value, out string? cmd))
{
// So we stop processing unless the cursor is right after 'ForEach-Object' or 'Where-Object'.
// Stop processing if the cursor is not at the end of an allowed command.
return default;
}
}

if (tokenAtCursor is not null && tokenAtCursor.TokenFlags.HasFlag(TokenFlags.CommandName))
if (cmd is "git")
{
// Process 'git' command.
return _gitHandler.GetGitResult(cmdAst, _cwd, context, cancellationToken);
}

if (cmdAst.CommandElements.Count != 1)
{
// For commands other than 'git', we only do argument completion if the cursor is right after the command name.
return default;
}
}
else
{
// When it's a command, it would likely take too much time because the command discovery is usually expensive, so we skip it.
return default;
if (tokenAtCursor.TokenFlags.HasFlag(TokenFlags.CommandName))
{
// When it's a command, it would likely take too much time because the command discovery is usually expensive, so we skip it.
return default;
}

if (IsCommandAstWithLiteralName(context, out var cmdAst, out var nameAst)
&& string.Equals(nameAst.Value, "git", StringComparison.OrdinalIgnoreCase))
{
return _gitHandler.GetGitResult(cmdAst, _cwd, context, cancellationToken);
}
}

return GetFromTabCompletion(context, cancellationToken);
}

private bool IsCommandAstWithLiteralName(
PredictionContext context,
[NotNullWhen(true)] out CommandAst? cmdAst,
[NotNullWhen(true)] out StringConstantExpressionAst? nameAst)
{
Ast lastAst = context.RelatedAsts[^1];
cmdAst = lastAst.Parent as CommandAst;
nameAst = cmdAst?.CommandElements[0] as StringConstantExpressionAst;
return nameAst is not null;
}

private SuggestionPackage GetFromTabCompletion(PredictionContext context, CancellationToken cancellationToken)
{
// Call into PowerShell tab completion to get completion results.
Expand Down Expand Up @@ -155,11 +190,18 @@ public void Dispose()

#region "Unused interface members because this predictor doesn't process feedback"

public bool CanAcceptFeedback(PredictionClient client, PredictorFeedbackKind feedback) => false;
public bool CanAcceptFeedback(PredictionClient client, PredictorFeedbackKind feedback)
{
return feedback == PredictorFeedbackKind.CommandLineAccepted ? true : false;
}

public void OnSuggestionDisplayed(PredictionClient client, uint session, int countOrIndex) { }
public void OnSuggestionAccepted(PredictionClient client, uint session, string acceptedSuggestion) { }
public void OnCommandLineAccepted(PredictionClient client, IReadOnlyList<string> history) { }
public void OnCommandLineExecuted(PredictionClient client, string commandLine, bool success) { }
public void OnCommandLineAccepted(PredictionClient client, IReadOnlyList<string> history)
{
_gitHandler.SignalCheckForRepoUpdate();
}

#endregion;
}
1 change: 1 addition & 0 deletions src/CompletionPredictorStateSync.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ private void SyncCurrentPath(Runspace source)
{
PathInfo currentPath = source.SessionStateProxy.Path.CurrentLocation;
_runspace.SessionStateProxy.Path.SetLocation(currentPath.Path);
_cwd = source.SessionStateProxy.Path.CurrentFileSystemLocation.ProviderPath;
}

private void SyncVariables(Runspace source)
Expand Down
144 changes: 144 additions & 0 deletions src/CustomHandlers/GitHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
using System.Collections.Concurrent;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem.Prediction;

namespace Microsoft.PowerShell.Predictor;

internal partial class GitHandler
{
private readonly ConcurrentDictionary<string, RepoInfo> _repos;
private readonly Dictionary<string, GitNode> _gitCmds;

internal GitHandler()
{
_repos = new(StringComparer.Ordinal);
_gitCmds = new(StringComparer.Ordinal)
{
{ "merge", new Merge() },
{ "branch", new Branch() },
{ "checkout", new Checkout() },
{ "push", new Push() },
};
}

internal void SignalCheckForRepoUpdate()
{
foreach (var repoInfo in _repos.Values)
{
repoInfo.NeedCheckForUpdate();
}
}

internal SuggestionPackage GetGitResult(CommandAst gitAst, string? cwd, PredictionContext context, CancellationToken token)
{
var elements = gitAst.CommandElements;
if (cwd is null || elements.Count is 1 || !TryConvertToText(elements, out List<string>? textElements))
{
return default;
}

RepoInfo? repoInfo = GetRepoInfo(cwd);
if (repoInfo is null || token.IsCancellationRequested)
{
return default;
}

string gitCmd = textElements[1];
string? textAtCursor = context.TokenAtCursor?.Text;
bool cursorAtGitCmd = textElements.Count is 2 && textAtCursor is not null;

if (!_gitCmds.TryGetValue(gitCmd, out GitNode? node))
{
if (cursorAtGitCmd)
{
foreach (var entry in _gitCmds)
{
if (entry.Key.StartsWith(textAtCursor!))
{
node = entry.Value;
break;
}
}
}
}

if (node is not null)
{
return node.Predict(textElements, textAtCursor, context.InputAst.Extent.Text, repoInfo, cursorAtGitCmd);
}

return default;
}

private bool TryConvertToText(
ReadOnlyCollection<CommandElementAst> elements,
[NotNullWhen(true)] out List<string>? textElements)
{
textElements = new(elements.Count);
foreach (var e in elements)
{
switch (e)
{
case StringConstantExpressionAst str:
textElements.Add(str.Value);
break;
case CommandParameterAst param:
textElements.Add(param.Extent.Text);
break;
default:
textElements = null;
return false;
}
}

return true;
}

private RepoInfo? GetRepoInfo(string cwd)
{
if (_repos.TryGetValue(cwd, out RepoInfo? repoInfo))
{
return repoInfo;
}

foreach (var entry in _repos)
{
string root = entry.Key;
if (cwd.StartsWith(root) && cwd[root.Length] == Path.DirectorySeparatorChar)
{
repoInfo = entry.Value;
break;
}
}

if (repoInfo is null)
{
string? repoRoot = FindRepoRoot(cwd);
if (repoRoot is not null)
{
repoInfo = _repos.GetOrAdd(repoRoot, new RepoInfo(repoRoot));
}
}

return repoInfo;
}

private string? FindRepoRoot(string currentLocation)
{
string? root = currentLocation;
while (root is not null)
{
string gitDir = Path.Join(root, ".git", "refs");
if (Directory.Exists(gitDir))
{
return root;
}

root = Path.GetDirectoryName(root);
}

return null;
}
}
Loading