Skip to content

Commit 175fba3

Browse files
authored
Support AzCommand using AZURE_CREDENTIALS (#204)
* Support AzCommand using AZURE_CREDENTIALS * spelling * concurrency
1 parent 50ab2ab commit 175fba3

File tree

6 files changed

+393
-1
lines changed

6 files changed

+393
-1
lines changed

src/Commands/Extension/AzCommand.cs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using AzureMcp.Arguments.Extension;
88
using AzureMcp.Models.Argument;
99
using AzureMcp.Models.Command;
10+
using AzureMcp.Services.Azure.Authentication;
1011
using AzureMcp.Services.Interfaces;
1112
using Microsoft.Extensions.Logging;
1213
using ModelContextProtocol.Server;
@@ -19,7 +20,8 @@ public sealed class AzCommand(ILogger<AzCommand> logger, int processTimeoutSecon
1920
private readonly int _processTimeoutSeconds = processTimeoutSeconds;
2021
private readonly Option<string> _commandOption = ArgumentDefinitions.Extension.Az.Command.ToOption();
2122
private static string? _cachedAzPath;
22-
23+
private volatile bool _isAuthenticated = false;
24+
private static readonly SemaphoreSlim _authSemaphore = new(1, 1);
2325

2426
protected override string GetCommandName() => "az";
2527

@@ -105,6 +107,55 @@ protected override AzArguments BindArguments(ParseResult parseResult)
105107
return null;
106108
}
107109

110+
private async Task<bool> AuthenticateWithAzureCredentialsAsync(IExternalProcessService processService, ILogger logger)
111+
{
112+
if (_isAuthenticated)
113+
{
114+
Console.WriteLine("Already authenticated with Azure CLI.1");
115+
return true;
116+
}
117+
118+
try
119+
{
120+
// Check if the semaphore is already acquired to avoid re-authentication
121+
bool isAcquired = await _authSemaphore.WaitAsync(1000);
122+
if (!isAcquired || _isAuthenticated)
123+
{
124+
return _isAuthenticated;
125+
}
126+
var credentials = AuthenticationUtils.GetAzureCredentials(logger);
127+
if (credentials == null)
128+
{
129+
logger.LogWarning("Invalid AZURE_CREDENTIALS format. Skipping authentication. Ensure it contains clientId, clientSecret, and tenantId.");
130+
return false;
131+
}
132+
133+
var azPath = FindAzCliPath() ?? throw new FileNotFoundException("Azure CLI executable not found in PATH or common installation locations. Please ensure Azure CLI is installed.");
134+
135+
var loginCommand = $"login --service-principal -u {credentials.ClientId} -p {credentials.ClientSecret} --tenant {credentials.TenantId}";
136+
var result = await processService.ExecuteAsync(azPath, loginCommand, 60);
137+
138+
if (result.ExitCode != 0)
139+
{
140+
logger.LogWarning("Failed to authenticate with Azure CLI. Error: {Error}", result.Error);
141+
return false;
142+
}
143+
144+
_isAuthenticated = true;
145+
logger.LogInformation("Successfully authenticated with Azure CLI using service principal.");
146+
return true;
147+
}
148+
catch (Exception ex)
149+
{
150+
logger.LogWarning(ex, "Error during service principal authentication. Command will proceed without authentication.");
151+
return false;
152+
}
153+
finally
154+
{
155+
_authSemaphore.Release();
156+
}
157+
}
158+
108159
[McpServerTool(Destructive = true, ReadOnly = false)]
109160
public override async Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult)
110161
{
@@ -121,6 +172,9 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
121172
var command = args.Command;
122173
var processService = context.GetService<IExternalProcessService>();
123174

175+
// Try to authenticate, but continue even if it fails
176+
await AuthenticateWithAzureCredentialsAsync(processService, _logger);
177+
124178
var azPath = FindAzCliPath() ?? throw new FileNotFoundException("Azure CLI executable not found in PATH or common installation locations. Please ensure Azure CLI is installed.");
125179
var result = await processService.ExecuteAsync(azPath, command, _processTimeoutSeconds);
126180

src/Commands/JsonSourceGenerationContext.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text.Json.Serialization;
55
using AzureMcp.Commands;
66
using AzureMcp.Commands.Group;
7+
using AzureMcp.Models;
78

89
namespace AzureMcp;
910

@@ -12,6 +13,7 @@ namespace AzureMcp;
1213
[JsonSerializable(typeof(JsonElement))]
1314
[JsonSerializable(typeof(List<string>))]
1415
[JsonSerializable(typeof(List<JsonNode>))]
16+
[JsonSerializable(typeof(AzureCredentials))]
1517
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)]
1618
internal partial class JsonSourceGenerationContext : JsonSerializerContext
1719
{

src/Models/AzureCredentials.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System.Text.Json.Serialization;
5+
6+
namespace AzureMcp.Models;
7+
8+
public record AzureCredentials(
9+
[property: JsonPropertyName("clientId")] string ClientId,
10+
[property: JsonPropertyName("clientSecret")] string ClientSecret,
11+
[property: JsonPropertyName("tenantId")] string TenantId,
12+
[property: JsonPropertyName("subscriptionId")] string? SubscriptionId = null
13+
);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System.Text.Json;
5+
using AzureMcp.Models;
6+
using Microsoft.Extensions.Logging;
7+
8+
namespace AzureMcp.Services.Azure.Authentication;
9+
10+
/// <summary>
11+
/// A utility class for handling Azure authentication.
12+
/// </summary>
13+
14+
public static class AuthenticationUtils
15+
{
16+
/// <summary>
17+
/// Fetches the Azure credentials from the environment variable AZURE_CREDENTIALS.
18+
/// </summary>
19+
public static AzureCredentials? GetAzureCredentials(ILogger logger)
20+
{
21+
var credentialsJson = Environment.GetEnvironmentVariable("AZURE_CREDENTIALS");
22+
if (string.IsNullOrEmpty(credentialsJson))
23+
{
24+
return null;
25+
}
26+
27+
try
28+
{
29+
// Use source-generated serialization to avoid trimmer warnings
30+
var credentials = JsonSerializer.Deserialize(credentialsJson, JsonSourceGenerationContext.Default.AzureCredentials);
31+
if (credentials == null)
32+
{
33+
logger.LogWarning("Invalid AZURE_CREDENTIALS format. Ensure it contains clientId, clientSecret, and tenantId.");
34+
return null;
35+
}
36+
return credentials;
37+
}
38+
catch (JsonException ex)
39+
{
40+
logger.LogWarning(ex, "Failed to deserialize AZURE_CREDENTIALS. Ensure it contains clientId, clientSecret, and tenantId.");
41+
return null;
42+
}
43+
}
44+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System.CommandLine.Parsing;
5+
using System.Text.Json;
6+
using System.Text.Json.Serialization;
7+
using AzureMcp.Commands.Extension;
8+
using AzureMcp.Models.Command;
9+
using AzureMcp.Services.Interfaces;
10+
using Microsoft.Extensions.DependencyInjection;
11+
using Microsoft.Extensions.Logging;
12+
using NSubstitute;
13+
using Xunit;
14+
15+
namespace AzureMcp.Tests.Commands.Extension;
16+
17+
public sealed class AzCommandTests
18+
{
19+
private readonly IServiceProvider _serviceProvider;
20+
private readonly IExternalProcessService _processService;
21+
private readonly ILogger<AzCommand> _logger;
22+
23+
public AzCommandTests()
24+
{
25+
_processService = Substitute.For<IExternalProcessService>();
26+
_logger = Substitute.For<ILogger<AzCommand>>();
27+
28+
var collection = new ServiceCollection();
29+
collection.AddSingleton(_processService);
30+
_serviceProvider = collection.BuildServiceProvider();
31+
}
32+
33+
[Fact]
34+
public void Execute_ReturnsArguments()
35+
{
36+
var command = new AzCommand(_logger);
37+
var arguments = command.GetArguments();
38+
39+
Assert.NotNull(arguments);
40+
}
41+
42+
[Fact]
43+
public async Task ExecuteAsync_ReturnsSuccessResult_WhenCommandExecutesSuccessfully()
44+
{
45+
using (new TestEnvVar(new Dictionary<string, string>
46+
{
47+
{ "AZURE_CREDENTIALS", """{"clientId": "myClientId","clientSecret": "myClientSecret","subscriptionId": "mySubscriptionID","tenantId": "myTenantId"}""" }
48+
}))
49+
{
50+
// Arrange
51+
var command = new AzCommand(_logger);
52+
var parser = new Parser(command.GetCommand());
53+
var args = parser.Parse("--command \"group list\"");
54+
var context = new CommandContext(_serviceProvider);
55+
56+
var expectedOutput = """{"value":[{"id":"/subscriptions/12345678-1234-1234-1234-123456789012/resourceGroups/test-rg","name":"test-rg","type":"Microsoft.Resources/resourceGroups","location":"eastus","properties":{"provisioningState":"Succeeded"}}]}""";
57+
var expectedJson = JsonDocument.Parse(expectedOutput).RootElement.Clone();
58+
59+
_processService.ExecuteAsync(
60+
Arg.Any<string>(),
61+
"group list",
62+
Arg.Any<int>(),
63+
Arg.Any<IEnumerable<string>>())
64+
.Returns(new ProcessResult(0, expectedOutput, string.Empty, "group list"));
65+
66+
_processService.ParseJsonOutput(Arg.Any<ProcessResult>())
67+
.Returns(expectedJson);
68+
69+
// Act
70+
var response = await command.ExecuteAsync(context, args);
71+
72+
// Assert
73+
Assert.NotNull(response);
74+
Assert.Equal(200, response.Status);
75+
Assert.NotNull(response.Results);
76+
77+
// Verify the ProcessService was called with expected parameters
78+
await _processService.Received().ExecuteAsync(
79+
Arg.Any<string>(),
80+
"group list",
81+
Arg.Any<int>(),
82+
Arg.Any<IEnumerable<string>>());
83+
84+
await _processService.Received().ExecuteAsync(
85+
Arg.Any<string>(),
86+
$"login --service-principal -u myClientId -p myClientSecret --tenant myTenantId",
87+
Arg.Any<int>(),
88+
Arg.Any<IEnumerable<string>>());
89+
}
90+
}
91+
92+
[Fact]
93+
public async Task ExecuteAsync_ReturnsErrorResponse_WhenCommandFails()
94+
{
95+
// Arrange
96+
var command = new AzCommand(_logger);
97+
var parser = new Parser(command.GetCommand());
98+
var args = parser.Parse("--command \"group invalid-command\"");
99+
var context = new CommandContext(_serviceProvider);
100+
101+
var errorMessage = "Error: az group: 'invalid-command' is not an az command.";
102+
103+
_processService.ExecuteAsync(
104+
Arg.Any<string>(),
105+
"group invalid-command",
106+
Arg.Any<int>(),
107+
Arg.Any<IEnumerable<string>>())
108+
.Returns(new ProcessResult(1, string.Empty, errorMessage, "group invalid-command"));
109+
110+
// Act
111+
var response = await command.ExecuteAsync(context, args);
112+
113+
// Assert
114+
Assert.NotNull(response);
115+
Assert.Equal(500, response.Status);
116+
Assert.Equal(errorMessage, response.Message);
117+
}
118+
119+
[Fact]
120+
public async Task ExecuteAsync_HandlesException_AndSetsException()
121+
{
122+
// Arrange
123+
var command = new AzCommand(_logger);
124+
var parser = new Parser(command.GetCommand());
125+
var args = parser.Parse("--command \"group list\"");
126+
var context = new CommandContext(_serviceProvider);
127+
128+
var exceptionMessage = "Azure CLI executable not found";
129+
130+
_processService.ExecuteAsync(
131+
Arg.Any<string>(),
132+
"group list",
133+
Arg.Any<int>(),
134+
Arg.Any<IEnumerable<string>>())
135+
.Returns(Task.FromException<ProcessResult>(new FileNotFoundException(exceptionMessage)));
136+
137+
// Act
138+
var response = await command.ExecuteAsync(context, args);
139+
140+
// Assert
141+
Assert.NotNull(response);
142+
Assert.Equal(500, response.Status);
143+
Assert.Contains("To mitigate this issue", response.Message);
144+
Assert.Contains(exceptionMessage, response.Message);
145+
}
146+
147+
[Fact]
148+
public async Task ExecuteAsync_ReturnsBadRequest_WhenMissingRequiredArguments()
149+
{
150+
// Arrange
151+
var command = new AzCommand(_logger);
152+
var parser = new Parser(command.GetCommand());
153+
var args = parser.Parse(""); // No command specified
154+
var context = new CommandContext(_serviceProvider);
155+
156+
// Act
157+
var response = await command.ExecuteAsync(context, args);
158+
159+
// Assert
160+
Assert.NotNull(response);
161+
Assert.Equal(400, response.Status);
162+
}
163+
164+
private sealed class AzResult
165+
{
166+
[JsonPropertyName("value")]
167+
public List<JsonElement> Value { get; set; } = new();
168+
}
169+
}

0 commit comments

Comments
 (0)