@@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
25
25
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
26
26
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
27
27
/// </summary>
28
- private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
29
- = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
30
28
private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
29
+ private static readonly ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap = new ( ) ;
30
+ private static readonly ConcurrentDictionary < TokenCredentialKey , TokenCredentialData > s_tokenCredentialMap = new ( ) ;
31
+ private static SemaphoreSlim s_pcaMapModifierSemaphore = new ( 1 , 1 ) ;
32
+ private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new ( 1 , 1 ) ;
31
33
private static readonly int s_accountPwCacheTtlInHours = 2 ;
32
34
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
33
35
private static readonly string s_defaultScopeSuffix = "/.default" ;
34
36
private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
35
- private readonly SqlClientLogger _logger = new SqlClientLogger ( ) ;
37
+ private readonly SqlClientLogger _logger = new ( ) ;
36
38
private Func < DeviceCodeResult , Task > _deviceCodeFlowCallback ;
37
39
private ICustomWebUi _customWebUI = null ;
38
40
private readonly string _applicationClientId = ActiveDirectoryAuthentication . AdoClientId ;
@@ -66,6 +68,11 @@ public static void ClearUserTokenCache()
66
68
{
67
69
s_pcaMap . Clear ( ) ;
68
70
}
71
+
72
+ if ( ! s_tokenCredentialMap . IsEmpty )
73
+ {
74
+ s_tokenCredentialMap . Clear ( ) ;
75
+ }
69
76
}
70
77
71
78
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetDeviceCodeFlowCallback/*'/>
@@ -145,50 +152,40 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
145
152
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
146
153
**/
147
154
148
- int seperatorIndex = parameters . Authority . LastIndexOf ( '/' ) ;
149
- string authority = parameters . Authority . Remove ( seperatorIndex + 1 ) ;
150
- string audience = parameters . Authority . Substring ( seperatorIndex + 1 ) ;
155
+ int separatorIndex = parameters . Authority . LastIndexOf ( '/' ) ;
156
+ string authority = parameters . Authority . Remove ( separatorIndex + 1 ) ;
157
+ string audience = parameters . Authority . Substring ( separatorIndex + 1 ) ;
151
158
string clientId = string . IsNullOrWhiteSpace ( parameters . UserId ) ? null : parameters . UserId ;
152
159
153
160
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDefault )
154
161
{
155
- DefaultAzureCredentialOptions defaultAzureCredentialOptions = new ( )
156
- {
157
- AuthorityHost = new Uri ( authority ) ,
158
- SharedTokenCacheTenantId = audience ,
159
- VisualStudioCodeTenantId = audience ,
160
- VisualStudioTenantId = audience ,
161
- ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
162
- } ;
163
-
164
- // Optionally set clientId when available
165
- if ( clientId is not null )
166
- {
167
- defaultAzureCredentialOptions . ManagedIdentityClientId = clientId ;
168
- defaultAzureCredentialOptions . SharedTokenCacheUsername = clientId ;
169
- }
170
- AccessToken accessToken = await new DefaultAzureCredential ( defaultAzureCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
162
+ // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId
163
+ TokenCredentialKey tokenCredentialKey = new ( typeof ( DefaultAzureCredential ) , authority , scope , audience , clientId ) ;
164
+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , string . Empty , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
171
165
SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
172
166
return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
173
167
}
174
168
175
- TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions ( ) { AuthorityHost = new Uri ( authority ) } ;
169
+ TokenCredentialOptions tokenCredentialOptions = new ( ) { AuthorityHost = new Uri ( authority ) } ;
176
170
177
171
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryManagedIdentity || parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryMSI )
178
172
{
179
- AccessToken accessToken = await new ManagedIdentityCredential ( clientId , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
173
+ // Cache ManagedIdentityCredential based on scope, authority, and clientId
174
+ TokenCredentialKey tokenCredentialKey = new ( typeof ( ManagedIdentityCredential ) , authority , scope , string . Empty , clientId ) ;
175
+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , string . Empty , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
180
176
SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
181
177
return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
182
178
}
183
179
184
180
AuthenticationResult result = null ;
185
181
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
186
182
{
187
- AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
183
+ // Cache ClientSecretCredential based on scope, authority, audience, and clientId
184
+ TokenCredentialKey tokenCredentialKey = new ( typeof ( ClientSecretCredential ) , authority , scope , audience , clientId ) ;
185
+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , parameters . Password , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
188
186
SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
189
187
return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
190
188
}
191
-
192
189
/*
193
190
* Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows
194
191
* (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend
@@ -204,7 +201,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
204
201
redirectUri = "http://localhost" ;
205
202
}
206
203
#endif
207
- PublicClientAppKey pcaKey = new PublicClientAppKey ( parameters . Authority , redirectUri , _applicationClientId
204
+ PublicClientAppKey pcaKey = new ( parameters . Authority , redirectUri , _applicationClientId
208
205
#if NETFRAMEWORK
209
206
, _iWin32WindowFunc
210
207
#endif
@@ -213,7 +210,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
213
210
#endif
214
211
) ;
215
212
216
- IPublicClientApplication app = GetPublicClientAppInstance ( pcaKey ) ;
213
+ IPublicClientApplication app = await GetPublicClientAppInstanceAsync ( pcaKey , cts . Token ) . ConfigureAwait ( false ) ;
217
214
218
215
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
219
216
{
@@ -248,7 +245,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
248
245
if ( null != previousPw &&
249
246
previousPw is byte [ ] previousPwBytes &&
250
247
// Only get the cached token if the current password hash matches the previously used password hash
251
- currPwHash . SequenceEqual ( previousPwBytes ) )
248
+ AreEqual ( currPwHash , previousPwBytes ) )
252
249
{
253
250
result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
254
251
}
@@ -353,7 +350,7 @@ private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlo
353
350
{
354
351
if ( authenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive )
355
352
{
356
- CancellationTokenSource ctsInteractive = new CancellationTokenSource ( ) ;
353
+ CancellationTokenSource ctsInteractive = new ( ) ;
357
354
#if NETCOREAPP
358
355
/*
359
356
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
@@ -447,16 +444,69 @@ public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec
447
444
=> _acquireAuthorizationCodeAsyncCallback . Invoke ( authorizationUri , redirectUri , cancellationToken ) ;
448
445
}
449
446
450
- private IPublicClientApplication GetPublicClientAppInstance ( PublicClientAppKey publicClientAppKey )
447
+ private async Task < IPublicClientApplication > GetPublicClientAppInstanceAsync ( PublicClientAppKey publicClientAppKey , CancellationToken cancellationToken )
451
448
{
452
449
if ( ! s_pcaMap . TryGetValue ( publicClientAppKey , out IPublicClientApplication clientApplicationInstance ) )
453
450
{
454
- clientApplicationInstance = CreateClientAppInstance ( publicClientAppKey ) ;
455
- s_pcaMap . TryAdd ( publicClientAppKey , clientApplicationInstance ) ;
451
+ await s_pcaMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
452
+ try
453
+ {
454
+ // Double-check in case another thread added it while we waited for the semaphore
455
+ if ( ! s_pcaMap . TryGetValue ( publicClientAppKey , out clientApplicationInstance ) )
456
+ {
457
+ clientApplicationInstance = CreateClientAppInstance ( publicClientAppKey ) ;
458
+ s_pcaMap . TryAdd ( publicClientAppKey , clientApplicationInstance ) ;
459
+ }
460
+ }
461
+ finally
462
+ {
463
+ s_pcaMapModifierSemaphore . Release ( ) ;
464
+ }
456
465
}
466
+
457
467
return clientApplicationInstance ;
458
468
}
459
469
470
+ private static async Task < AccessToken > GetTokenAsync ( TokenCredentialKey tokenCredentialKey , string secret ,
471
+ TokenRequestContext tokenRequestContext , CancellationToken cancellationToken )
472
+ {
473
+ if ( ! s_tokenCredentialMap . TryGetValue ( tokenCredentialKey , out TokenCredentialData tokenCredentialInstance ) )
474
+ {
475
+ await s_tokenCredentialMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
476
+ try
477
+ {
478
+ // Double-check in case another thread added it while we waited for the semaphore
479
+ if ( ! s_tokenCredentialMap . TryGetValue ( tokenCredentialKey , out tokenCredentialInstance ) )
480
+ {
481
+ tokenCredentialInstance = CreateTokenCredentialInstance ( tokenCredentialKey , secret ) ;
482
+ s_tokenCredentialMap . TryAdd ( tokenCredentialKey , tokenCredentialInstance ) ;
483
+ }
484
+ }
485
+ finally
486
+ {
487
+ s_tokenCredentialMapModifierSemaphore . Release ( ) ;
488
+ }
489
+ }
490
+
491
+ if ( ! AreEqual ( tokenCredentialInstance . _secretHash , GetHash ( secret ) ) )
492
+ {
493
+ // If the secret hash has changed, we need to remove the old token credential instance and create a new one.
494
+ await s_tokenCredentialMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
495
+ try
496
+ {
497
+ s_tokenCredentialMap . TryRemove ( tokenCredentialKey , out _ ) ;
498
+ tokenCredentialInstance = CreateTokenCredentialInstance ( tokenCredentialKey , secret ) ;
499
+ s_tokenCredentialMap . TryAdd ( tokenCredentialKey , tokenCredentialInstance ) ;
500
+ }
501
+ finally
502
+ {
503
+ s_tokenCredentialMapModifierSemaphore . Release ( ) ;
504
+ }
505
+ }
506
+
507
+ return await tokenCredentialInstance . _tokenCredential . GetTokenAsync ( tokenRequestContext , cancellationToken ) ;
508
+ }
509
+
460
510
private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
461
511
{
462
512
return parameters . Authority + "+" + parameters . UserId ;
@@ -470,6 +520,24 @@ private static byte[] GetHash(string input)
470
520
return hashedBytes ;
471
521
}
472
522
523
+ private static bool AreEqual ( byte [ ] a1 , byte [ ] a2 )
524
+ {
525
+ if ( ReferenceEquals ( a1 , a2 ) )
526
+ {
527
+ return true ;
528
+ }
529
+ else if ( a1 is null || a2 is null )
530
+ {
531
+ return false ;
532
+ }
533
+ else if ( a1 . Length != a2 . Length )
534
+ {
535
+ return false ;
536
+ }
537
+
538
+ return a1 . AsSpan ( ) . SequenceEqual ( a2 . AsSpan ( ) ) ;
539
+ }
540
+
473
541
private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
474
542
{
475
543
IPublicClientApplication publicClientApplication ;
@@ -513,6 +581,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ
513
581
return publicClientApplication ;
514
582
}
515
583
584
+ private static TokenCredentialData CreateTokenCredentialInstance ( TokenCredentialKey tokenCredentialKey , string secret )
585
+ {
586
+ if ( tokenCredentialKey . _tokenCredentialType == typeof ( DefaultAzureCredential ) )
587
+ {
588
+ DefaultAzureCredentialOptions defaultAzureCredentialOptions = new ( )
589
+ {
590
+ AuthorityHost = new Uri ( tokenCredentialKey . _authority ) ,
591
+ SharedTokenCacheTenantId = tokenCredentialKey . _audience ,
592
+ VisualStudioCodeTenantId = tokenCredentialKey . _audience ,
593
+ VisualStudioTenantId = tokenCredentialKey . _audience ,
594
+ ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
595
+ } ;
596
+
597
+ // Optionally set clientId when available
598
+ if ( tokenCredentialKey . _clientId is not null )
599
+ {
600
+ defaultAzureCredentialOptions . ManagedIdentityClientId = tokenCredentialKey . _clientId ;
601
+ defaultAzureCredentialOptions . SharedTokenCacheUsername = tokenCredentialKey . _clientId ;
602
+ defaultAzureCredentialOptions . WorkloadIdentityClientId = tokenCredentialKey . _clientId ;
603
+ }
604
+
605
+ return new TokenCredentialData ( new DefaultAzureCredential ( defaultAzureCredentialOptions ) , GetHash ( secret ) ) ;
606
+ }
607
+
608
+ TokenCredentialOptions tokenCredentialOptions = new ( ) { AuthorityHost = new Uri ( tokenCredentialKey . _authority ) } ;
609
+
610
+ if ( tokenCredentialKey . _tokenCredentialType == typeof ( ManagedIdentityCredential ) )
611
+ {
612
+ return new TokenCredentialData ( new ManagedIdentityCredential ( tokenCredentialKey . _clientId , tokenCredentialOptions ) , GetHash ( secret ) ) ;
613
+ }
614
+ else if ( tokenCredentialKey . _tokenCredentialType == typeof ( ClientSecretCredential ) )
615
+ {
616
+ return new TokenCredentialData ( new ClientSecretCredential ( tokenCredentialKey . _audience , tokenCredentialKey . _clientId , secret , tokenCredentialOptions ) , GetHash ( secret ) ) ;
617
+ }
618
+ else if ( tokenCredentialKey . _tokenCredentialType == typeof ( WorkloadIdentityCredential ) )
619
+ {
620
+ // The WorkloadIdentityCredentialOptions object initialization populates its instance members
621
+ // from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE,
622
+ // and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id.
623
+ WorkloadIdentityCredentialOptions options = new ( ) { AuthorityHost = new Uri ( tokenCredentialKey . _authority ) } ;
624
+
625
+ if ( tokenCredentialKey . _clientId is not null )
626
+ {
627
+ options . ClientId = tokenCredentialKey . _clientId ;
628
+ }
629
+
630
+ return new TokenCredentialData ( new WorkloadIdentityCredential ( options ) , GetHash ( secret ) ) ;
631
+ }
632
+
633
+ // This should never be reached, but if it is, throw an exception that will be noticed during development
634
+ throw new ArgumentException ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
635
+ }
636
+
516
637
internal class PublicClientAppKey
517
638
{
518
639
public readonly string _authority ;
@@ -572,5 +693,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app
572
693
#endif
573
694
) . GetHashCode ( ) ;
574
695
}
696
+
697
+ internal class TokenCredentialData
698
+ {
699
+ public TokenCredential _tokenCredential ;
700
+ public byte [ ] _secretHash ;
701
+
702
+ public TokenCredentialData ( TokenCredential tokenCredential , byte [ ] secretHash )
703
+ {
704
+ _tokenCredential = tokenCredential ;
705
+ _secretHash = secretHash ;
706
+ }
707
+ }
708
+
709
+ internal class TokenCredentialKey
710
+ {
711
+ public readonly Type _tokenCredentialType ;
712
+ public readonly string _authority ;
713
+ public readonly string _scope ;
714
+ public readonly string _audience ;
715
+ public readonly string _clientId ;
716
+
717
+ public TokenCredentialKey ( Type tokenCredentialType , string authority , string scope , string audience , string clientId )
718
+ {
719
+ _tokenCredentialType = tokenCredentialType ;
720
+ _authority = authority ;
721
+ _scope = scope ;
722
+ _audience = audience ;
723
+ _clientId = clientId ;
724
+ }
725
+
726
+ public override bool Equals ( object obj )
727
+ {
728
+ if ( obj != null && obj is TokenCredentialKey tcKey )
729
+ {
730
+ return string . CompareOrdinal ( nameof ( _tokenCredentialType ) , nameof ( tcKey . _tokenCredentialType ) ) == 0
731
+ && string . CompareOrdinal ( _authority , tcKey . _authority ) == 0
732
+ && string . CompareOrdinal ( _scope , tcKey . _scope ) == 0
733
+ && string . CompareOrdinal ( _audience , tcKey . _audience ) == 0
734
+ && string . CompareOrdinal ( _clientId , tcKey . _clientId ) == 0
735
+ ;
736
+ }
737
+ return false ;
738
+ }
739
+
740
+ public override int GetHashCode ( ) => Tuple . Create ( _tokenCredentialType , _authority , _scope , _audience , _clientId ) . GetHashCode ( ) ;
741
+ }
742
+
575
743
}
576
744
}
0 commit comments