Skip to content

Commit dbd89af

Browse files
authored
Add argument prediction for some basic git command (#21)
1 parent f860b83 commit dbd89af

File tree

5 files changed

+980
-11
lines changed

5 files changed

+980
-11
lines changed

src/CompletionPredictor.cs

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Diagnostics.CodeAnalysis;
12
using System.Management.Automation.Language;
23
using System.Management.Automation.Runspaces;
34
using System.Management.Automation.Subsystem.Prediction;
@@ -10,6 +11,8 @@ public partial class CompletionPredictor : ICommandPredictor, IDisposable
1011
{
1112
private readonly Guid _guid;
1213
private readonly Runspace _runspace;
14+
private readonly GitHandler _gitHandler;
15+
private string? _cwd;
1316
private int _lock = 1;
1417

1518
private static HashSet<string> s_cmdList = new(StringComparer.OrdinalIgnoreCase)
@@ -21,11 +24,13 @@ public partial class CompletionPredictor : ICommandPredictor, IDisposable
2124
"where",
2225
"Where-Object",
2326
"cd",
27+
"git",
2428
};
2529

2630
internal CompletionPredictor(string guid)
2731
{
2832
_guid = new Guid(guid);
33+
_gitHandler = new GitHandler();
2934
_runspace = RunspaceFactory.CreateRunspace(InitialSessionState.CreateDefault());
3035
_runspace.Name = nameof(CompletionPredictor);
3136
_runspace.Open();
@@ -47,25 +52,55 @@ public SuggestionPackage GetSuggestion(PredictionClient client, PredictionContex
4752
{
4853
// When it ends at a white space, it would likely trigger argument completion which in most cases would be file-operation
4954
// intensive. That's not only slow but also undesirable in most cases, so we skip it.
50-
// But, there are exceptions for 'ForEach-Object' and 'Where-Object', where completion on member names is quite useful.
51-
Ast lastAst = relatedAsts[^1];
52-
var cmdName = (lastAst.Parent as CommandAst)?.CommandElements[0] as StringConstantExpressionAst;
53-
if (cmdName is null || !s_cmdList.Contains(cmdName.Value) || !object.ReferenceEquals(lastAst, cmdName))
55+
// But, there are exceptions for some commands, where completion on member names is quite useful.
56+
if (!IsCommandAstWithLiteralName(context, out var cmdAst, out var nameAst)
57+
|| !s_cmdList.TryGetValue(nameAst.Value, out string? cmd))
5458
{
55-
// So we stop processing unless the cursor is right after 'ForEach-Object' or 'Where-Object'.
59+
// Stop processing if the cursor is not at the end of an allowed command.
5660
return default;
5761
}
58-
}
5962

60-
if (tokenAtCursor is not null && tokenAtCursor.TokenFlags.HasFlag(TokenFlags.CommandName))
63+
if (cmd is "git")
64+
{
65+
// Process 'git' command.
66+
return _gitHandler.GetGitResult(cmdAst, _cwd, context, cancellationToken);
67+
}
68+
69+
if (cmdAst.CommandElements.Count != 1)
70+
{
71+
// For commands other than 'git', we only do argument completion if the cursor is right after the command name.
72+
return default;
73+
}
74+
}
75+
else
6176
{
62-
// When it's a command, it would likely take too much time because the command discovery is usually expensive, so we skip it.
63-
return default;
77+
if (tokenAtCursor.TokenFlags.HasFlag(TokenFlags.CommandName))
78+
{
79+
// When it's a command, it would likely take too much time because the command discovery is usually expensive, so we skip it.
80+
return default;
81+
}
82+
83+
if (IsCommandAstWithLiteralName(context, out var cmdAst, out var nameAst)
84+
&& string.Equals(nameAst.Value, "git", StringComparison.OrdinalIgnoreCase))
85+
{
86+
return _gitHandler.GetGitResult(cmdAst, _cwd, context, cancellationToken);
87+
}
6488
}
6589

6690
return GetFromTabCompletion(context, cancellationToken);
6791
}
6892

93+
private bool IsCommandAstWithLiteralName(
94+
PredictionContext context,
95+
[NotNullWhen(true)] out CommandAst? cmdAst,
96+
[NotNullWhen(true)] out StringConstantExpressionAst? nameAst)
97+
{
98+
Ast lastAst = context.RelatedAsts[^1];
99+
cmdAst = lastAst.Parent as CommandAst;
100+
nameAst = cmdAst?.CommandElements[0] as StringConstantExpressionAst;
101+
return nameAst is not null;
102+
}
103+
69104
private SuggestionPackage GetFromTabCompletion(PredictionContext context, CancellationToken cancellationToken)
70105
{
71106
// Call into PowerShell tab completion to get completion results.
@@ -155,11 +190,18 @@ public void Dispose()
155190

156191
#region "Unused interface members because this predictor doesn't process feedback"
157192

158-
public bool CanAcceptFeedback(PredictionClient client, PredictorFeedbackKind feedback) => false;
193+
public bool CanAcceptFeedback(PredictionClient client, PredictorFeedbackKind feedback)
194+
{
195+
return feedback == PredictorFeedbackKind.CommandLineAccepted ? true : false;
196+
}
197+
159198
public void OnSuggestionDisplayed(PredictionClient client, uint session, int countOrIndex) { }
160199
public void OnSuggestionAccepted(PredictionClient client, uint session, string acceptedSuggestion) { }
161-
public void OnCommandLineAccepted(PredictionClient client, IReadOnlyList<string> history) { }
162200
public void OnCommandLineExecuted(PredictionClient client, string commandLine, bool success) { }
201+
public void OnCommandLineAccepted(PredictionClient client, IReadOnlyList<string> history)
202+
{
203+
_gitHandler.SignalCheckForRepoUpdate();
204+
}
163205

164206
#endregion;
165207
}

src/CompletionPredictorStateSync.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ private void SyncCurrentPath(Runspace source)
149149
{
150150
PathInfo currentPath = source.SessionStateProxy.Path.CurrentLocation;
151151
_runspace.SessionStateProxy.Path.SetLocation(currentPath.Path);
152+
_cwd = source.SessionStateProxy.Path.CurrentFileSystemLocation.ProviderPath;
152153
}
153154

154155
private void SyncVariables(Runspace source)

src/CustomHandlers/GitHandler.cs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
using System.Collections.Concurrent;
2+
using System.Collections.ObjectModel;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Management.Automation.Language;
5+
using System.Management.Automation.Subsystem.Prediction;
6+
7+
namespace Microsoft.PowerShell.Predictor;
8+
9+
internal partial class GitHandler
10+
{
11+
private readonly ConcurrentDictionary<string, RepoInfo> _repos;
12+
private readonly Dictionary<string, GitNode> _gitCmds;
13+
14+
internal GitHandler()
15+
{
16+
_repos = new(StringComparer.Ordinal);
17+
_gitCmds = new(StringComparer.Ordinal)
18+
{
19+
{ "merge", new Merge() },
20+
{ "branch", new Branch() },
21+
{ "checkout", new Checkout() },
22+
{ "push", new Push() },
23+
};
24+
}
25+
26+
internal void SignalCheckForRepoUpdate()
27+
{
28+
foreach (var repoInfo in _repos.Values)
29+
{
30+
repoInfo.NeedCheckForUpdate();
31+
}
32+
}
33+
34+
internal SuggestionPackage GetGitResult(CommandAst gitAst, string? cwd, PredictionContext context, CancellationToken token)
35+
{
36+
var elements = gitAst.CommandElements;
37+
if (cwd is null || elements.Count is 1 || !TryConvertToText(elements, out List<string>? textElements))
38+
{
39+
return default;
40+
}
41+
42+
RepoInfo? repoInfo = GetRepoInfo(cwd);
43+
if (repoInfo is null || token.IsCancellationRequested)
44+
{
45+
return default;
46+
}
47+
48+
string gitCmd = textElements[1];
49+
string? textAtCursor = context.TokenAtCursor?.Text;
50+
bool cursorAtGitCmd = textElements.Count is 2 && textAtCursor is not null;
51+
52+
if (!_gitCmds.TryGetValue(gitCmd, out GitNode? node))
53+
{
54+
if (cursorAtGitCmd)
55+
{
56+
foreach (var entry in _gitCmds)
57+
{
58+
if (entry.Key.StartsWith(textAtCursor!))
59+
{
60+
node = entry.Value;
61+
break;
62+
}
63+
}
64+
}
65+
}
66+
67+
if (node is not null)
68+
{
69+
return node.Predict(textElements, textAtCursor, context.InputAst.Extent.Text, repoInfo, cursorAtGitCmd);
70+
}
71+
72+
return default;
73+
}
74+
75+
private bool TryConvertToText(
76+
ReadOnlyCollection<CommandElementAst> elements,
77+
[NotNullWhen(true)] out List<string>? textElements)
78+
{
79+
textElements = new(elements.Count);
80+
foreach (var e in elements)
81+
{
82+
switch (e)
83+
{
84+
case StringConstantExpressionAst str:
85+
textElements.Add(str.Value);
86+
break;
87+
case CommandParameterAst param:
88+
textElements.Add(param.Extent.Text);
89+
break;
90+
default:
91+
textElements = null;
92+
return false;
93+
}
94+
}
95+
96+
return true;
97+
}
98+
99+
private RepoInfo? GetRepoInfo(string cwd)
100+
{
101+
if (_repos.TryGetValue(cwd, out RepoInfo? repoInfo))
102+
{
103+
return repoInfo;
104+
}
105+
106+
foreach (var entry in _repos)
107+
{
108+
string root = entry.Key;
109+
if (cwd.StartsWith(root) && cwd[root.Length] == Path.DirectorySeparatorChar)
110+
{
111+
repoInfo = entry.Value;
112+
break;
113+
}
114+
}
115+
116+
if (repoInfo is null)
117+
{
118+
string? repoRoot = FindRepoRoot(cwd);
119+
if (repoRoot is not null)
120+
{
121+
repoInfo = _repos.GetOrAdd(repoRoot, new RepoInfo(repoRoot));
122+
}
123+
}
124+
125+
return repoInfo;
126+
}
127+
128+
private string? FindRepoRoot(string currentLocation)
129+
{
130+
string? root = currentLocation;
131+
while (root is not null)
132+
{
133+
string gitDir = Path.Join(root, ".git", "refs");
134+
if (Directory.Exists(gitDir))
135+
{
136+
return root;
137+
}
138+
139+
root = Path.GetDirectoryName(root);
140+
}
141+
142+
return null;
143+
}
144+
}

0 commit comments

Comments
 (0)