Skip to content

Commit ca8746f

Browse files
authored
[5.1] Add | Cache TokenCredential objects to take advantage of token caching (#2380) (#2776)
* Add | Cache TokenCredential objects to take advantage of token caching (#2380)
1 parent 6fe8e21 commit ca8746f

File tree

1 file changed

+201
-33
lines changed

1 file changed

+201
-33
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 201 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2525
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
2626
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
2727
/// </summary>
28-
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
29-
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
3028
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
29+
private static readonly ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap = new();
30+
private static readonly ConcurrentDictionary<TokenCredentialKey, TokenCredentialData> s_tokenCredentialMap = new();
31+
private static SemaphoreSlim s_pcaMapModifierSemaphore = new(1, 1);
32+
private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new(1, 1);
3133
private static readonly int s_accountPwCacheTtlInHours = 2;
3234
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
3335
private static readonly string s_defaultScopeSuffix = "/.default";
3436
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
35-
private readonly SqlClientLogger _logger = new SqlClientLogger();
37+
private readonly SqlClientLogger _logger = new();
3638
private Func<DeviceCodeResult, Task> _deviceCodeFlowCallback;
3739
private ICustomWebUi _customWebUI = null;
3840
private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId;
@@ -66,6 +68,11 @@ public static void ClearUserTokenCache()
6668
{
6769
s_pcaMap.Clear();
6870
}
71+
72+
if (!s_tokenCredentialMap.IsEmpty)
73+
{
74+
s_tokenCredentialMap.Clear();
75+
}
6976
}
7077

7178
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetDeviceCodeFlowCallback/*'/>
@@ -145,50 +152,40 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
145152
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
146153
**/
147154

148-
int seperatorIndex = parameters.Authority.LastIndexOf('/');
149-
string authority = parameters.Authority.Remove(seperatorIndex + 1);
150-
string audience = parameters.Authority.Substring(seperatorIndex + 1);
155+
int separatorIndex = parameters.Authority.LastIndexOf('/');
156+
string authority = parameters.Authority.Remove(separatorIndex + 1);
157+
string audience = parameters.Authority.Substring(separatorIndex + 1);
151158
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
152159

153160
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
154161
{
155-
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
156-
{
157-
AuthorityHost = new Uri(authority),
158-
SharedTokenCacheTenantId = audience,
159-
VisualStudioCodeTenantId = audience,
160-
VisualStudioTenantId = audience,
161-
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
162-
};
163-
164-
// Optionally set clientId when available
165-
if (clientId is not null)
166-
{
167-
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
168-
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
169-
}
170-
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
162+
// Cache DefaultAzureCredenial based on scope, authority, audience, and clientId
163+
TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId);
164+
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
171165
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
172166
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
173167
}
174168

175-
TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions() { AuthorityHost = new Uri(authority) };
169+
TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) };
176170

177171
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
178172
{
179-
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
173+
// Cache ManagedIdentityCredential based on scope, authority, and clientId
174+
TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId);
175+
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
180176
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
181177
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
182178
}
183179

184180
AuthenticationResult result = null;
185181
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
186182
{
187-
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
183+
// Cache ClientSecretCredential based on scope, authority, audience, and clientId
184+
TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId);
185+
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false);
188186
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
189187
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
190188
}
191-
192189
/*
193190
* Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows
194191
* (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend
@@ -204,7 +201,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
204201
redirectUri = "http://localhost";
205202
}
206203
#endif
207-
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
204+
PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId
208205
#if NETFRAMEWORK
209206
, _iWin32WindowFunc
210207
#endif
@@ -213,7 +210,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
213210
#endif
214211
);
215212

216-
IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
213+
IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false);
217214

218215
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
219216
{
@@ -248,7 +245,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
248245
if (null != previousPw &&
249246
previousPw is byte[] previousPwBytes &&
250247
// Only get the cached token if the current password hash matches the previously used password hash
251-
currPwHash.SequenceEqual(previousPwBytes))
248+
AreEqual(currPwHash, previousPwBytes))
252249
{
253250
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
254251
}
@@ -353,7 +350,7 @@ private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlo
353350
{
354351
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
355352
{
356-
CancellationTokenSource ctsInteractive = new CancellationTokenSource();
353+
CancellationTokenSource ctsInteractive = new();
357354
#if NETCOREAPP
358355
/*
359356
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
@@ -447,16 +444,69 @@ public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec
447444
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
448445
}
449446

450-
private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
447+
private async Task<IPublicClientApplication> GetPublicClientAppInstanceAsync(PublicClientAppKey publicClientAppKey, CancellationToken cancellationToken)
451448
{
452449
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
453450
{
454-
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
455-
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
451+
await s_pcaMapModifierSemaphore.WaitAsync(cancellationToken);
452+
try
453+
{
454+
// Double-check in case another thread added it while we waited for the semaphore
455+
if (!s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance))
456+
{
457+
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
458+
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
459+
}
460+
}
461+
finally
462+
{
463+
s_pcaMapModifierSemaphore.Release();
464+
}
456465
}
466+
457467
return clientApplicationInstance;
458468
}
459469

470+
private static async Task<AccessToken> GetTokenAsync(TokenCredentialKey tokenCredentialKey, string secret,
471+
TokenRequestContext tokenRequestContext, CancellationToken cancellationToken)
472+
{
473+
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out TokenCredentialData tokenCredentialInstance))
474+
{
475+
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
476+
try
477+
{
478+
// Double-check in case another thread added it while we waited for the semaphore
479+
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out tokenCredentialInstance))
480+
{
481+
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
482+
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
483+
}
484+
}
485+
finally
486+
{
487+
s_tokenCredentialMapModifierSemaphore.Release();
488+
}
489+
}
490+
491+
if (!AreEqual(tokenCredentialInstance._secretHash, GetHash(secret)))
492+
{
493+
// If the secret hash has changed, we need to remove the old token credential instance and create a new one.
494+
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
495+
try
496+
{
497+
s_tokenCredentialMap.TryRemove(tokenCredentialKey, out _);
498+
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
499+
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
500+
}
501+
finally
502+
{
503+
s_tokenCredentialMapModifierSemaphore.Release();
504+
}
505+
}
506+
507+
return await tokenCredentialInstance._tokenCredential.GetTokenAsync(tokenRequestContext, cancellationToken);
508+
}
509+
460510
private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
461511
{
462512
return parameters.Authority + "+" + parameters.UserId;
@@ -470,6 +520,24 @@ private static byte[] GetHash(string input)
470520
return hashedBytes;
471521
}
472522

523+
private static bool AreEqual(byte[] a1, byte[] a2)
524+
{
525+
if (ReferenceEquals(a1, a2))
526+
{
527+
return true;
528+
}
529+
else if (a1 is null || a2 is null)
530+
{
531+
return false;
532+
}
533+
else if (a1.Length != a2.Length)
534+
{
535+
return false;
536+
}
537+
538+
return a1.AsSpan().SequenceEqual(a2.AsSpan());
539+
}
540+
473541
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
474542
{
475543
IPublicClientApplication publicClientApplication;
@@ -513,6 +581,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ
513581
return publicClientApplication;
514582
}
515583

584+
private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret)
585+
{
586+
if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential))
587+
{
588+
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
589+
{
590+
AuthorityHost = new Uri(tokenCredentialKey._authority),
591+
SharedTokenCacheTenantId = tokenCredentialKey._audience,
592+
VisualStudioCodeTenantId = tokenCredentialKey._audience,
593+
VisualStudioTenantId = tokenCredentialKey._audience,
594+
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
595+
};
596+
597+
// Optionally set clientId when available
598+
if (tokenCredentialKey._clientId is not null)
599+
{
600+
defaultAzureCredentialOptions.ManagedIdentityClientId = tokenCredentialKey._clientId;
601+
defaultAzureCredentialOptions.SharedTokenCacheUsername = tokenCredentialKey._clientId;
602+
defaultAzureCredentialOptions.WorkloadIdentityClientId = tokenCredentialKey._clientId;
603+
}
604+
605+
return new TokenCredentialData(new DefaultAzureCredential(defaultAzureCredentialOptions), GetHash(secret));
606+
}
607+
608+
TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };
609+
610+
if (tokenCredentialKey._tokenCredentialType == typeof(ManagedIdentityCredential))
611+
{
612+
return new TokenCredentialData(new ManagedIdentityCredential(tokenCredentialKey._clientId, tokenCredentialOptions), GetHash(secret));
613+
}
614+
else if (tokenCredentialKey._tokenCredentialType == typeof(ClientSecretCredential))
615+
{
616+
return new TokenCredentialData(new ClientSecretCredential(tokenCredentialKey._audience, tokenCredentialKey._clientId, secret, tokenCredentialOptions), GetHash(secret));
617+
}
618+
else if (tokenCredentialKey._tokenCredentialType == typeof(WorkloadIdentityCredential))
619+
{
620+
// The WorkloadIdentityCredentialOptions object initialization populates its instance members
621+
// from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE,
622+
// and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id.
623+
WorkloadIdentityCredentialOptions options = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };
624+
625+
if (tokenCredentialKey._clientId is not null)
626+
{
627+
options.ClientId = tokenCredentialKey._clientId;
628+
}
629+
630+
return new TokenCredentialData(new WorkloadIdentityCredential(options), GetHash(secret));
631+
}
632+
633+
// This should never be reached, but if it is, throw an exception that will be noticed during development
634+
throw new ArgumentException(nameof(ActiveDirectoryAuthenticationProvider));
635+
}
636+
516637
internal class PublicClientAppKey
517638
{
518639
public readonly string _authority;
@@ -572,5 +693,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app
572693
#endif
573694
).GetHashCode();
574695
}
696+
697+
internal class TokenCredentialData
698+
{
699+
public TokenCredential _tokenCredential;
700+
public byte[] _secretHash;
701+
702+
public TokenCredentialData(TokenCredential tokenCredential, byte[] secretHash)
703+
{
704+
_tokenCredential = tokenCredential;
705+
_secretHash = secretHash;
706+
}
707+
}
708+
709+
internal class TokenCredentialKey
710+
{
711+
public readonly Type _tokenCredentialType;
712+
public readonly string _authority;
713+
public readonly string _scope;
714+
public readonly string _audience;
715+
public readonly string _clientId;
716+
717+
public TokenCredentialKey(Type tokenCredentialType, string authority, string scope, string audience, string clientId)
718+
{
719+
_tokenCredentialType = tokenCredentialType;
720+
_authority = authority;
721+
_scope = scope;
722+
_audience = audience;
723+
_clientId = clientId;
724+
}
725+
726+
public override bool Equals(object obj)
727+
{
728+
if (obj != null && obj is TokenCredentialKey tcKey)
729+
{
730+
return string.CompareOrdinal(nameof(_tokenCredentialType), nameof(tcKey._tokenCredentialType)) == 0
731+
&& string.CompareOrdinal(_authority, tcKey._authority) == 0
732+
&& string.CompareOrdinal(_scope, tcKey._scope) == 0
733+
&& string.CompareOrdinal(_audience, tcKey._audience) == 0
734+
&& string.CompareOrdinal(_clientId, tcKey._clientId) == 0
735+
;
736+
}
737+
return false;
738+
}
739+
740+
public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode();
741+
}
742+
575743
}
576744
}

0 commit comments

Comments
 (0)