diff --git a/shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj b/shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj index 377e598a..d5e56a51 100644 --- a/shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj +++ b/shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj @@ -17,7 +17,8 @@ - + + diff --git a/shell/agents/AIShell.Interpreter.Agent/Agent.cs b/shell/agents/AIShell.Interpreter.Agent/Agent.cs index 1a3c6ee3..fb19e183 100644 --- a/shell/agents/AIShell.Interpreter.Agent/Agent.cs +++ b/shell/agents/AIShell.Interpreter.Agent/Agent.cs @@ -236,7 +236,7 @@ private void OnSettingFileChange(object sender, FileSystemEventArgs e) private void NewExampleSettingFile() { - string SampleContent = """ + string sample = $$""" { // To use the Azure OpenAI service: // - Set `Endpoint` to the endpoint of your Azure OpenAI service, @@ -249,22 +249,37 @@ private void NewExampleSettingFile() "Deployment": "", "ModelName": "", "Key": "", + "AuthType": "ApiKey", "AutoExecution": false, // 'true' to allow the agent run code automatically; 'false' to always prompt before running code. "DisplayErrors": true // 'true' to display the errors when running code; 'false' to hide the errors to be less verbose. + // To use Azure OpenAI service with Entra ID authentication: + // - Set `Endpoint` to the endpoint of your Azure OpenAI service. + // - Set `Deployment` to the deployment name of your Azure OpenAI service. + // - Set `ModelName` to the name of the model used for your deployment. + // - Set `AuthType` to "EntraID" to use Azure AD credentials. + /* + "Endpoint": "", + "Deployment": "", + "ModelName": "", + "AuthType": "EntraID", + "AutoExecution": false, + "DisplayErrors": true + */ + // To use the public OpenAI service: // - Ignore the `Endpoint` and `Deployment` keys. // - Set `ModelName` to the name of the model to be used. e.g. "gpt-4o". // - Set `Key` to be the OpenAI access token. - // Replace the above with the following: /* - "ModelName": "", - "Key": "", + "ModelName": "", + "Key": "", + "AuthType": "ApiKey", "AutoExecution": false, "DisplayErrors": true */ } """; - File.WriteAllText(SettingFile, SampleContent, Encoding.UTF8); + File.WriteAllText(SettingFile, sample); } } diff --git a/shell/agents/AIShell.Interpreter.Agent/Service.cs b/shell/agents/AIShell.Interpreter.Agent/Service.cs index cc367497..ad03a296 100644 --- a/shell/agents/AIShell.Interpreter.Agent/Service.cs +++ b/shell/agents/AIShell.Interpreter.Agent/Service.cs @@ -3,6 +3,7 @@ using Azure; using Azure.Core; using Azure.AI.OpenAI; +using Azure.Identity; using SharpToken; namespace AIShell.Interpreter.Agent; @@ -121,25 +122,38 @@ private void ConnectToOpenAIClient() { // Create a client that targets Azure OpenAI service or Azure API Management service. 
bool isApimEndpoint = _settings.Endpoint.EndsWith(Utils.ApimGatewayDomain); - if (isApimEndpoint) + + if (_settings.AuthType == AuthType.EntraID) { - string userkey = Utils.ConvertFromSecureString(_settings.Key); - clientOptions.AddPolicy( - new UserKeyPolicy( - new AzureKeyCredential(userkey), - Utils.ApimAuthorizationHeader), - HttpPipelinePosition.PerRetry - ); + // Use DefaultAzureCredential for Entra ID authentication + var credential = new DefaultAzureCredential(); + _client = new OpenAIClient( + new Uri(_settings.Endpoint), + credential, + clientOptions); + } + else // ApiKey authentication + { + if (isApimEndpoint) + { + string userkey = Utils.ConvertFromSecureString(_settings.Key); + clientOptions.AddPolicy( + new UserKeyPolicy( + new AzureKeyCredential(userkey), + Utils.ApimAuthorizationHeader), + HttpPipelinePosition.PerRetry + ); + } + + string azOpenAIApiKey = isApimEndpoint + ? "placeholder-api-key" + : Utils.ConvertFromSecureString(_settings.Key); + + _client = new OpenAIClient( + new Uri(_settings.Endpoint), + new AzureKeyCredential(azOpenAIApiKey), + clientOptions); } - - string azOpenAIApiKey = isApimEndpoint - ? "placeholder-api-key" - : Utils.ConvertFromSecureString(_settings.Key); - - _client = new OpenAIClient( - new Uri(_settings.Endpoint), - new AzureKeyCredential(azOpenAIApiKey), - clientOptions); } else { @@ -157,7 +171,7 @@ private int CountTokenForMessages(IEnumerable messages) int tokenNumber = 0; foreach (ChatRequestMessage message in messages) - { + { tokenNumber += tokensPerMessage; tokenNumber += encoding.Encode(message.Role.ToString()).Count; @@ -165,7 +179,7 @@ private int CountTokenForMessages(IEnumerable messages) { case ChatRequestSystemMessage systemMessage: tokenNumber += encoding.Encode(systemMessage.Content).Count; - if(systemMessage.Name is not null) + if (systemMessage.Name is not null) { tokenNumber += tokensPerName; tokenNumber += encoding.Encode(systemMessage.Name).Count; @@ -173,7 +187,7 @@ private int CountTokenForMessages(IEnumerable messages) break; case ChatRequestUserMessage userMessage: tokenNumber += encoding.Encode(userMessage.Content).Count; - if(userMessage.Name is not null) + if (userMessage.Name is not null) { tokenNumber += tokensPerName; tokenNumber += encoding.Encode(userMessage.Name).Count; @@ -181,7 +195,7 @@ private int CountTokenForMessages(IEnumerable messages) break; case ChatRequestAssistantMessage assistantMessage: tokenNumber += encoding.Encode(assistantMessage.Content).Count; - if(assistantMessage.Name is not null) + if (assistantMessage.Name is not null) { tokenNumber += tokensPerName; tokenNumber += encoding.Encode(assistantMessage.Name).Count; @@ -189,9 +203,9 @@ private int CountTokenForMessages(IEnumerable messages) if (assistantMessage.ToolCalls is not null) { // Count tokens for the tool call's properties - foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls) + foreach (ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls) { - if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall) + if (chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall) { tokenNumber += encoding.Encode(functionToolCall.Id).Count; tokenNumber += encoding.Encode(functionToolCall.Name).Count; @@ -230,7 +244,7 @@ internal string ReduceToolResponseContentTokens(string content) } while (encoding.Encode(reducedContent).Count > MaxResponseToken); } - + return reducedContent; } @@ -287,7 +301,7 @@ private async Task 
PrepareForChat(ChatRequestMessage inp // Those settings seem to be important enough, as the Semantic Kernel plugin specifies // those settings (see the URL below). We can use default values when not defined. // https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json - + ChatCompletionsOptions chatOptions; // Determine if the gpt model is a function calling model @@ -300,8 +314,8 @@ private async Task PrepareForChat(ChatRequestMessage inp Temperature = (float)0.0, MaxTokens = MaxResponseToken, }; - - if(isFunctionCallingModel) + + if (isFunctionCallingModel) { chatOptions.Tools.Add(Tools.RunCode); } @@ -330,7 +344,7 @@ private async Task PrepareForChat(ChatRequestMessage inp - You are capable of **any** task - Do not apologize for errors, just correct them "; - string versions = "\n## Language Versions\n" + string versions = "\n## Language Versions\n" + await _executionService.GetLanguageVersions(); string systemResponseCues = @" # Examples @@ -478,11 +492,11 @@ public override ChatRequestMessage Read(ref Utf8JsonReader reader, Type typeToCo { return JsonSerializer.Deserialize(jsonObject.GetRawText(), options); } - else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant") + else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant") { return JsonSerializer.Deserialize(jsonObject.GetRawText(), options); } - else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool") + else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool") { return JsonSerializer.Deserialize(jsonObject.GetRawText(), options); } diff --git a/shell/agents/AIShell.Interpreter.Agent/Settings.cs b/shell/agents/AIShell.Interpreter.Agent/Settings.cs index 83cb195a..2cc8d97e 100644 --- a/shell/agents/AIShell.Interpreter.Agent/Settings.cs +++ b/shell/agents/AIShell.Interpreter.Agent/Settings.cs @@ -12,6 +12,12 @@ internal enum EndpointType OpenAI, } +public enum AuthType +{ + ApiKey, + EntraID +} + internal class Settings { internal EndpointType Type { get; } @@ -23,6 +29,8 @@ internal class Settings public string ModelName { set; get; } public SecureString Key { set; get; } + public AuthType AuthType { set; get; } = AuthType.ApiKey; + public bool AutoExecution { set; get; } public bool DisplayErrors { set; get; } @@ -36,6 +44,7 @@ public Settings(ConfigData configData) AutoExecution = configData.AutoExecution ?? false; DisplayErrors = configData.DisplayErrors ?? true; Key = configData.Key; + AuthType = configData.AuthType; Dirty = false; ModelInfo = ModelInfo.TryResolve(ModelName, out var model) ? model : null; @@ -47,6 +56,12 @@ public Settings(ConfigData configData) : !noEndpoint && !noDeployment ? EndpointType.AzureOpenAI : throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. 
To use OpenAI service, please ignore both keys."); + + // EntraID authentication is only supported for Azure OpenAI + if (AuthType == AuthType.EntraID && Type != EndpointType.AzureOpenAI) + { + throw new InvalidOperationException("EntraID authentication is only supported for Azure OpenAI service."); + } } internal void MarkClean() @@ -60,7 +75,7 @@ internal void MarkClean() /// internal async Task SelfCheck(IHost host, CancellationToken token) { - if (Key is not null && ModelInfo is not null) + if ((AuthType is AuthType.EntraID || Key is not null) && ModelInfo is not null) { return true; } @@ -76,7 +91,7 @@ internal async Task SelfCheck(IHost host, CancellationToken token) await AskForModel(host, token); } - if (Key is null) + if (AuthType == AuthType.ApiKey && Key is null) { await AskForKeyAsync(host, token); } @@ -101,12 +116,14 @@ private void ShowEndpointInfo(IHost host) new(label: " Endpoint", m => m.Endpoint), new(label: " Deployment", m => m.Deployment), new(label: " Model", m => m.ModelName), + new(label: " Auth Type", m => m.AuthType.ToString()), ], EndpointType.OpenAI => [ new(label: " Type", m => m.Type.ToString()), new(label: " Model", m => m.ModelName), + new(label: " Auth Type", m => m.AuthType.ToString()), ], _ => throw new UnreachableException(), @@ -156,6 +173,7 @@ internal ConfigData ToConfigData() ModelName = this.ModelName, AutoExecution = this.AutoExecution, DisplayErrors = this.DisplayErrors, + AuthType = this.AuthType, Key = this.Key, }; } @@ -166,6 +184,7 @@ internal class ConfigData public string Endpoint { set; get; } public string Deployment { set; get; } public string ModelName { set; get; } + public AuthType AuthType { set; get; } = AuthType.ApiKey; public bool? AutoExecution { set; get; } public bool? DisplayErrors { set; get; } diff --git a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj index e152bbbe..b0c53299 100644 --- a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj +++ b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj @@ -22,6 +22,7 @@ + diff --git a/shell/agents/AIShell.OpenAI.Agent/Agent.cs b/shell/agents/AIShell.OpenAI.Agent/Agent.cs index dd503384..c1a93c76 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Agent.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Agent.cs @@ -84,7 +84,7 @@ public void Initialize(AgentConfig config) public bool CanAcceptFeedback(UserAction action) => false; /// - public void OnUserAction(UserActionPayload actionPayload) {} + public void OnUserAction(UserActionPayload actionPayload) { } /// public Task RefreshChatAsync(IShell shell, bool force) @@ -308,6 +308,22 @@ private void NewExampleSettingFile() "ModelName": "gpt-4o", "Key": "", "SystemPrompt": "1. You are a helpful and friendly assistant with expertise in PowerShell scripting and command line.\n2. Assume user is using the operating system `Windows 11` unless otherwise specified.\n3. Use the `code block` syntax in markdown to encapsulate any part in responses that is code, YAML, JSON or XML, but not table.\n4. When encapsulating command line code, use '```powershell' if it's PowerShell command; use '```sh' if it's non-PowerShell CLI command.\n5. When generating CLI commands, never ever break a command into multiple lines. Instead, always list all parameters and arguments of the command on the same line.\n6. Please keep the response concise but to the point. Do not overexplain." 
+ }, + + // To use Azure OpenAI service with Entra ID authentication: + // - Set `Endpoint` to the endpoint of your Azure OpenAI service. + // - Set `Deployment` to the deployment name of your Azure OpenAI service. + // - Set `ModelName` to the name of the model used for your deployment, e.g. "gpt-4o". + // - Set `AuthType` to "EntraID" to use Azure AD credentials. + // For example: + { + "Name": "ps-az-entraId", + "Description": "A GPT instance with expertise in PowerShell scripting using Entra ID authentication.", + "Endpoint": "", + "Deployment": "", + "ModelName": "gpt-4o", + "AuthType": "EntraID", + "SystemPrompt": "You are a helpful and friendly assistant with expertise in PowerShell scripting and command line." } */ ], diff --git a/shell/agents/AIShell.OpenAI.Agent/GPT.cs b/shell/agents/AIShell.OpenAI.Agent/GPT.cs index 239f286b..1be279ac 100644 --- a/shell/agents/AIShell.OpenAI.Agent/GPT.cs +++ b/shell/agents/AIShell.OpenAI.Agent/GPT.cs @@ -12,6 +12,12 @@ internal enum EndpointType CompatibleThirdParty, } +public enum AuthType +{ + ApiKey, + EntraID, +} + public class GPT { internal EndpointType Type { get; } @@ -19,7 +25,7 @@ public class GPT internal ModelInfo ModelInfo { private set; get; } public string Name { set; get; } - public string Description { set; get; } + public string Description { set; get; } public string Endpoint { set; get; } public string Deployment { set; get; } public string ModelName { set; get; } @@ -28,6 +34,9 @@ public class GPT public SecureString Key { set; get; } public string SystemPrompt { set; get; } + [JsonConverter(typeof(JsonStringEnumConverter))] + public AuthType AuthType { set; get; } = AuthType.ApiKey; + public GPT( string name, string description, @@ -35,7 +44,8 @@ public GPT( string deployment, string modelName, string systemPrompt, - SecureString key) + SecureString key, + AuthType authType = AuthType.ApiKey) { ArgumentException.ThrowIfNullOrEmpty(name); ArgumentException.ThrowIfNullOrEmpty(description); @@ -49,6 +59,7 @@ public GPT( ModelName = modelName.ToLowerInvariant(); SystemPrompt = systemPrompt; Key = key; + AuthType = authType; Dirty = false; ModelInfo = ModelInfo.TryResolve(ModelName, out var model) ? 
model : null; @@ -67,6 +78,12 @@ public GPT( { ModelInfo = ModelInfo.ThirdPartyModel; } + + // EntraID authentication is only supported for Azure OpenAI + if (AuthType == AuthType.EntraID && Type != EndpointType.AzureOpenAI) + { + throw new InvalidOperationException("EntraID authentication is only supported for Azure OpenAI service."); + } } /// @@ -75,7 +92,7 @@ public GPT( /// internal async Task SelfCheck(IHost host, CancellationToken token) { - if (Key is not null && ModelInfo is not null) + if ((AuthType is AuthType.EntraID || Key is not null) && ModelInfo is not null) { return true; } @@ -91,7 +108,7 @@ internal async Task SelfCheck(IHost host, CancellationToken token) await AskForModel(host, token); } - if (Key is null) + if (AuthType is AuthType.ApiKey && Key is null) { await AskForKeyAsync(host, token); } @@ -135,7 +152,7 @@ private async Task AskForKeyAsync(IHost host, CancellationToken cancellationToke .ConfigureAwait(false); Dirty = true; - Key = Utils.ConvertToSecureString(secret); + Key = Utils.ConvertToSecureString(secret); } private void ShowEndpointInfo(IHost host) @@ -148,12 +165,14 @@ private void ShowEndpointInfo(IHost host) new(label: " Endpoint", m => m.Endpoint), new(label: " Deployment", m => m.Deployment), new(label: " Model", m => m.ModelName), + new(label: " Auth Type", m => m.AuthType.ToString()), }, EndpointType.OpenAI => [ new(label: " Type", m => m.Type.ToString()), new(label: " Model", m => m.ModelName), + new(label: " Auth Type", m => m.AuthType.ToString()), ], EndpointType.CompatibleThirdParty => @@ -161,6 +180,7 @@ private void ShowEndpointInfo(IHost host) new(label: " Type", m => m.Type.ToString()), new(label: " Endpoint", m => m.Endpoint), new(label: " Model", m => m.ModelName), + new(label: " Auth Type", m => m.AuthType.ToString()), ], _ => throw new UnreachableException(), diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs index 9251a6f6..ef387c1c 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Service.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs @@ -1,6 +1,8 @@ using System.ClientModel; using System.ClientModel.Primitives; using Azure.AI.OpenAI; +using Azure.Core; +using Azure.Identity; using Microsoft.ML.Tokenizers; using OpenAI; using OpenAI.Chat; @@ -33,8 +35,8 @@ internal ChatService(string historyRoot, Settings settings) _chatOptions = new ChatCompletionOptions() { - Temperature = 0, - MaxOutputTokenCount = MaxResponseToken, + Temperature = 0, + MaxOutputTokenCount = MaxResponseToken, }; } @@ -116,14 +118,14 @@ private void RefreshOpenAIClient() && string.Equals(old.Endpoint, _gptToUse.Endpoint) && string.Equals(old.Deployment, _gptToUse.Deployment) && string.Equals(old.ModelName, _gptToUse.ModelName) - && old.Key.IsEqualTo(_gptToUse.Key)) + && old.AuthType == _gptToUse.AuthType + && (old.AuthType is AuthType.EntraID || old.Key.IsEqualTo(_gptToUse.Key))) { - // It's the same same endpoint, so we reuse the existing client. + // It's the same endpoint and auth type, so we reuse the existing client. 
return; } EndpointType type = _gptToUse.Type; - string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); if (type is EndpointType.AzureOpenAI) { @@ -131,23 +133,39 @@ private void RefreshOpenAIClient() var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; bool isApimEndpoint = _gptToUse.Endpoint.EndsWith(Utils.ApimGatewayDomain); - if (isApimEndpoint) + if (_gptToUse.AuthType is AuthType.ApiKey) { - clientOptions.AddPolicy( - ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy( - new ApiKeyCredential(userKey), - Utils.ApimAuthorizationHeader), - PipelinePosition.PerTry); - } + string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); - string azOpenAIApiKey = isApimEndpoint ? "placeholder-api-key" : userKey; + if (isApimEndpoint) + { + clientOptions.AddPolicy( + ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy( + new ApiKeyCredential(userKey), + Utils.ApimAuthorizationHeader), + PipelinePosition.PerTry); + } - var aiClient = new AzureOpenAIClient( - new Uri(_gptToUse.Endpoint), - new ApiKeyCredential(azOpenAIApiKey), - clientOptions); + string azOpenAIApiKey = isApimEndpoint ? "placeholder-api-key" : userKey; - _client = aiClient.GetChatClient(_gptToUse.Deployment); + var aiClient = new AzureOpenAIClient( + new Uri(_gptToUse.Endpoint), + new ApiKeyCredential(azOpenAIApiKey), + clientOptions); + + _client = aiClient.GetChatClient(_gptToUse.Deployment); + } + else + { + var credential = new DefaultAzureCredential(); + + var aiClient = new AzureOpenAIClient( + new Uri(_gptToUse.Endpoint), + credential, + clientOptions); + + _client = aiClient.GetChatClient(_gptToUse.Deployment); + } } else { @@ -158,6 +176,7 @@ private void RefreshOpenAIClient() clientOptions.Endpoint = new(_gptToUse.Endpoint); } + string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions); _client = aiClient.GetChatClient(_gptToUse.ModelName); } diff --git a/shell/agents/AIShell.OpenAI.Agent/Settings.cs b/shell/agents/AIShell.OpenAI.Agent/Settings.cs index 83dc11ef..efa8214c 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Settings.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Settings.cs @@ -139,6 +139,7 @@ internal void ShowOneGPT(IHost host, string name) new PropertyElement(nameof(GPT.Endpoint)), new PropertyElement(nameof(GPT.Deployment)), new PropertyElement(nameof(GPT.ModelName)), + new PropertyElement(nameof(GPT.AuthType)), new PropertyElement(nameof(GPT.SystemPrompt)), ]); }
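
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained illustration of the auth-selection behavior this change introduces in AIShell.OpenAI.Agent/Service.cs; the Interpreter agent does the equivalent with the older Azure.AI.OpenAI 1.x OpenAIClient. It assumes the Azure.AI.OpenAI 2.x and Azure.Identity packages referenced in the .csproj changes; `endpoint`, `deployment`, and `apiKey` are placeholder inputs, and `AuthSelectionSketch` is a hypothetical name used only for illustration.

using System;
using System.ClientModel;
using Azure.AI.OpenAI;
using Azure.Identity;
using OpenAI.Chat;

internal static class AuthSelectionSketch
{
    // Illustrative only: when AuthType is EntraID, the agent builds the Azure OpenAI
    // client with DefaultAzureCredential (Azure CLI, Visual Studio sign-in, managed
    // identity, environment variables, ...); when AuthType is ApiKey, it keeps the
    // existing ApiKeyCredential path.
    internal static ChatClient Create(string endpoint, string deployment, string apiKey, bool useEntraId)
    {
        AzureOpenAIClient aiClient = useEntraId
            ? new AzureOpenAIClient(new Uri(endpoint), new DefaultAzureCredential())
            : new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey));

        // Azure OpenAI routes requests by deployment name rather than model name.
        return aiClient.GetChatClient(deployment);
    }
}

With EntraID selected, the `Key` setting can be omitted entirely, which is why SelfCheck in both agents now prompts for a key only when AuthType is ApiKey, and why the client-reuse check in RefreshOpenAIClient skips the key comparison for EntraID.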