7
7
using Azure . Security . KeyVault . Keys . Cryptography ;
8
8
using System ;
9
9
using System . Collections . Concurrent ;
10
- using System . Threading . Tasks ;
10
+ using System . Threading ;
11
11
using static Azure . Security . KeyVault . Keys . Cryptography . SignatureAlgorithm ;
12
12
13
13
namespace Microsoft . Data . SqlClient . AlwaysEncrypted . AzureKeyVaultProvider
14
14
{
15
- internal class AzureSqlKeyCryptographer
15
+ internal sealed class AzureSqlKeyCryptographer : IDisposable
16
16
{
17
17
/// <summary>
18
18
/// TokenCredential to be used with the KeyClient
@@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer
25
25
private readonly ConcurrentDictionary < Uri , KeyClient > _keyClientDictionary = new ( ) ;
26
26
27
27
/// <summary>
28
- /// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI).
29
- /// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
30
- /// key into the key dictionary.
28
+ /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
31
29
/// </summary>
32
- private readonly ConcurrentDictionary < string , Task < Azure . Response < KeyVaultKey > > > _keyFetchTaskDictionary = new ( ) ;
30
+ private readonly ConcurrentDictionary < string , KeyVaultKey > _keyDictionary = new ( ) ;
33
31
34
32
/// <summary>
35
- /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI) .
33
+ /// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys .
36
34
/// </summary>
37
- private readonly ConcurrentDictionary < string , KeyVaultKey > _keyDictionary = new ( ) ;
35
+ private SemaphoreSlim _keyDictionarySemaphore = new ( 1 , 1 ) ;
38
36
39
37
/// <summary>
40
38
/// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI).
@@ -50,20 +48,44 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential)
50
48
TokenCredential = tokenCredential ;
51
49
}
52
50
51
+ /// <summary>
52
+ /// Disposes the SemaphoreSlim used for thread safety.
53
+ /// </summary>
54
+ public void Dispose ( )
55
+ {
56
+ _keyDictionarySemaphore . Dispose ( ) ;
57
+ }
58
+
53
59
/// <summary>
54
60
/// Adds the key, specified by the Key Identifier URI, to the cache.
61
+ /// Validates the key type and fetches the key from Azure Key Vault if it is not already cached.
55
62
/// </summary>
56
63
/// <param name="keyIdentifierUri"></param>
57
64
internal void AddKey ( string keyIdentifierUri )
58
65
{
59
- if ( TheKeyHasNotBeenCached ( keyIdentifierUri ) )
66
+ // Allow only one thread to proceed to ensure thread safety
67
+ // as we will need to fetch key information from Azure Key Vault if the key is not found in cache.
68
+ _keyDictionarySemaphore . Wait ( ) ;
69
+
70
+ try
60
71
{
61
- ParseAKVPath ( keyIdentifierUri , out Uri vaultUri , out string keyName , out string keyVersion ) ;
62
- CreateKeyClient ( vaultUri ) ;
63
- FetchKey ( vaultUri , keyName , keyVersion , keyIdentifierUri ) ;
64
- }
72
+ if ( ! _keyDictionary . ContainsKey ( keyIdentifierUri ) )
73
+ {
74
+ ParseAKVPath ( keyIdentifierUri , out Uri vaultUri , out string keyName , out string keyVersion ) ;
75
+
76
+ // Fetch the KeyClient for the Key vault URI.
77
+ KeyClient keyClient = GetOrCreateKeyClient ( vaultUri ) ;
78
+
79
+ // Fetch the key from Azure Key Vault.
80
+ KeyVaultKey key = FetchKeyFromKeyVault ( keyClient , keyName , keyVersion ) ;
65
81
66
- bool TheKeyHasNotBeenCached ( string k ) => ! _keyDictionary . ContainsKey ( k ) && ! _keyFetchTaskDictionary . ContainsKey ( k ) ;
82
+ _keyDictionary . AddOrUpdate ( keyIdentifierUri , key , ( k , v ) => key ) ;
83
+ }
84
+ }
85
+ finally
86
+ {
87
+ _keyDictionarySemaphore . Release ( ) ;
88
+ }
67
89
}
68
90
69
91
/// <summary>
@@ -75,18 +97,12 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
75
97
{
76
98
if ( _keyDictionary . TryGetValue ( keyIdentifierUri , out KeyVaultKey key ) )
77
99
{
78
- AKVEventSource . Log . TryTraceEvent ( "Fetched master key from cache" ) ;
100
+ AKVEventSource . Log . TryTraceEvent ( "Fetched key name={0} from cache" , key . Name ) ;
79
101
return key ;
80
102
}
81
103
82
- if ( _keyFetchTaskDictionary . TryGetValue ( keyIdentifierUri , out Task < Azure . Response < KeyVaultKey > > task ) )
83
- {
84
- AKVEventSource . Log . TryTraceEvent ( "New Master key fetched." ) ;
85
- return Task . Run ( ( ) => task ) . GetAwaiter ( ) . GetResult ( ) ;
86
- }
87
-
88
104
// Not a public exception - not likely to occur.
89
- AKVEventSource . Log . TryTraceEvent ( "Master key not found." ) ;
105
+ AKVEventSource . Log . TryTraceEvent ( "Key not found; URI={0}" , keyIdentifierUri ) ;
90
106
throw ADP . MasterKeyNotFound ( keyIdentifierUri ) ;
91
107
}
92
108
@@ -95,10 +111,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
95
111
/// </summary>
96
112
/// <param name="keyIdentifierUri">The key vault key identifier URI</param>
97
113
/// <returns></returns>
98
- internal int GetKeySize ( string keyIdentifierUri )
99
- {
100
- return GetKey ( keyIdentifierUri ) . Key . N . Length ;
101
- }
114
+ internal int GetKeySize ( string keyIdentifierUri ) => GetKey ( keyIdentifierUri ) . Key . N . Length ;
102
115
103
116
/// <summary>
104
117
/// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
@@ -142,49 +155,58 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)
142
155
143
156
CryptographyClient cryptographyClient = new ( GetKey ( keyIdentifierUri ) . Id , TokenCredential ) ;
144
157
_cryptoClientDictionary . TryAdd ( keyIdentifierUri , cryptographyClient ) ;
145
-
146
158
return cryptographyClient ;
147
159
}
148
160
149
161
/// <summary>
150
- ///
162
+ /// Fetches the column encryption key from the Azure Key Vault.
151
163
/// </summary>
152
- /// <param name="vaultUri ">The Azure Key Vault URI </param>
164
+ /// <param name="keyClient ">The KeyClient instance </param>
153
165
/// <param name="keyName">The name of the Azure Key Vault key</param>
154
166
/// <param name="keyVersion">The version of the Azure Key Vault key</param>
155
- /// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
156
- private void FetchKey ( Uri vaultUri , string keyName , string keyVersion , string keyResourceUri )
167
+ private KeyVaultKey FetchKeyFromKeyVault ( KeyClient keyClient , string keyName , string keyVersion )
157
168
{
158
- Task < Azure . Response < KeyVaultKey > > fetchKeyTask = FetchKeyFromKeyVault ( vaultUri , keyName , keyVersion ) ;
159
- _keyFetchTaskDictionary . AddOrUpdate ( keyResourceUri , fetchKeyTask , ( k , v ) => fetchKeyTask ) ;
169
+ AKVEventSource . Log . TryTraceEvent ( "Fetching key name={0}" , keyName ) ;
160
170
161
- fetchKeyTask
162
- . ContinueWith ( k => ValidateRsaKey ( k . GetAwaiter ( ) . GetResult ( ) ) )
163
- . ContinueWith ( k => _keyDictionary . AddOrUpdate ( keyResourceUri , k . GetAwaiter ( ) . GetResult ( ) , ( key , v ) => k . GetAwaiter ( ) . GetResult ( ) ) ) ;
171
+ Azure . Response < KeyVaultKey > keyResponse = keyClient ? . GetKey ( keyName , keyVersion ) ;
164
172
165
- Task . Run ( ( ) => fetchKeyTask ) ;
173
+ // Handle the case where the key response is null or contains an error
174
+ // This can happen if the key does not exist or if there is an issue with the KeyClient.
175
+ // In such cases, we log the error and throw an exception.
176
+ if ( keyResponse == null || keyResponse . Value == null || keyResponse . GetRawResponse ( ) . IsError )
177
+ {
178
+ AKVEventSource . Log . TryTraceEvent ( "Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}" , keyName , keyVersion ) ;
179
+ if ( keyResponse ? . GetRawResponse ( ) is Azure . Response response )
180
+ {
181
+ AKVEventSource . Log . TryTraceEvent ( "Response status {0} : {1}" , response . Status , response . ReasonPhrase ) ;
182
+ }
183
+ throw ADP . GetKeyFailed ( keyName ) ;
184
+ }
185
+
186
+ KeyVaultKey key = keyResponse . Value ;
187
+
188
+ // Validate that the key is of type RSA
189
+ key = ValidateRsaKey ( key ) ;
190
+ return key ;
166
191
}
167
192
168
193
/// <summary>
169
- /// Looks up the KeyClient object by it's URI and then fetches the key by name .
194
+ /// Gets or creates a KeyClient for the specified Azure Key Vault URI .
170
195
/// </summary>
171
- /// <param name="vaultUri">The Azure Key Vault URI</param>
172
- /// <param name="keyName">Then name of the key</param>
173
- /// <param name="keyVersion">Then version of the key</param>
196
+ /// <param name="vaultUri">Key Identifier URL</param>
174
197
/// <returns></returns>
175
- private Task < Azure . Response < KeyVaultKey > > FetchKeyFromKeyVault ( Uri vaultUri , string keyName , string keyVersion )
198
+ private KeyClient GetOrCreateKeyClient ( Uri vaultUri )
176
199
{
177
- _keyClientDictionary . TryGetValue ( vaultUri , out KeyClient keyClient ) ;
178
- AKVEventSource . Log . TryTraceEvent ( "Fetching requested master key: {0}" , keyName ) ;
179
- return keyClient ? . GetKeyAsync ( keyName , keyVersion ) ;
200
+ return _keyClientDictionary . GetOrAdd (
201
+ vaultUri , ( _ ) => new KeyClient ( vaultUri , TokenCredential ) ) ;
180
202
}
181
203
182
204
/// <summary>
183
205
/// Validates that a key is of type RSA
184
206
/// </summary>
185
207
/// <param name="key"></param>
186
208
/// <returns></returns>
187
- private KeyVaultKey ValidateRsaKey ( KeyVaultKey key )
209
+ private static KeyVaultKey ValidateRsaKey ( KeyVaultKey key )
188
210
{
189
211
if ( key . KeyType != KeyType . Rsa && key . KeyType != KeyType . RsaHsm )
190
212
{
@@ -195,26 +217,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
195
217
return key ;
196
218
}
197
219
198
- /// <summary>
199
- /// Instantiates and adds a KeyClient to the KeyClient dictionary
200
- /// </summary>
201
- /// <param name="vaultUri">The Azure Key Vault URI</param>
202
- private void CreateKeyClient ( Uri vaultUri )
203
- {
204
- if ( ! _keyClientDictionary . ContainsKey ( vaultUri ) )
205
- {
206
- _keyClientDictionary . TryAdd ( vaultUri , new KeyClient ( vaultUri , TokenCredential ) ) ;
207
- }
208
- }
209
-
210
220
/// <summary>
211
221
/// Validates and parses the Azure Key Vault URI and key name.
212
222
/// </summary>
213
223
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
214
224
/// <param name="vaultUri">The Azure Key Vault URI</param>
215
225
/// <param name="masterKeyName">The name of the key</param>
216
226
/// <param name="masterKeyVersion">The version of the key</param>
217
- private void ParseAKVPath ( string masterKeyPath , out Uri vaultUri , out string masterKeyName , out string masterKeyVersion )
227
+ private static void ParseAKVPath ( string masterKeyPath , out Uri vaultUri , out string masterKeyName , out string masterKeyVersion )
218
228
{
219
229
Uri masterKeyPathUri = new ( masterKeyPath ) ;
220
230
vaultUri = new Uri ( masterKeyPathUri . GetLeftPart ( UriPartial . Authority ) ) ;
0 commit comments