Skip to content

Commit 8435d00

Browse files
Fix encryption key cache design for AKV provider (#3464)
* Fix AzureSqlKeyCryptographer to be pure sync * Fix Symmetric key cache design + Key vault cache * Fix tests * Add back multi-user cache scenario * Add cache lock in global CEK cache to prevent concurrency issues. * Touch-ups * More changes to streamline getKey calls * More updates * - Addressed review comments. * - Fixing tests that expect certain localized strings that have changed. --------- Co-authored-by: Paul Medynski <31868385+paulmedynski@users.noreply.github.com>
1 parent 951da45 commit 8435d00

File tree

10 files changed

+199
-132
lines changed

10 files changed

+199
-132
lines changed

src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/AzureSqlKeyCryptographer.cs

Lines changed: 69 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
using Azure.Security.KeyVault.Keys.Cryptography;
88
using System;
99
using System.Collections.Concurrent;
10-
using System.Threading.Tasks;
10+
using System.Threading;
1111
using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm;
1212

1313
namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
1414
{
15-
internal class AzureSqlKeyCryptographer
15+
internal sealed class AzureSqlKeyCryptographer : IDisposable
1616
{
1717
/// <summary>
1818
/// TokenCredential to be used with the KeyClient
@@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer
2525
private readonly ConcurrentDictionary<Uri, KeyClient> _keyClientDictionary = new();
2626

2727
/// <summary>
28-
/// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI).
29-
/// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
30-
/// key into the key dictionary.
28+
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
3129
/// </summary>
32-
private readonly ConcurrentDictionary<string, Task<Azure.Response<KeyVaultKey>>> _keyFetchTaskDictionary = new();
30+
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();
3331

3432
/// <summary>
35-
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
33+
/// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys.
3634
/// </summary>
37-
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();
35+
private SemaphoreSlim _keyDictionarySemaphore = new(1, 1);
3836

3937
/// <summary>
4038
/// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI).
@@ -50,20 +48,44 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential)
5048
TokenCredential = tokenCredential;
5149
}
5250

51+
/// <summary>
52+
/// Disposes the SemaphoreSlim used for thread safety.
53+
/// </summary>
54+
public void Dispose()
55+
{
56+
_keyDictionarySemaphore.Dispose();
57+
}
58+
5359
/// <summary>
5460
/// Adds the key, specified by the Key Identifier URI, to the cache.
61+
/// Validates the key type and fetches the key from Azure Key Vault if it is not already cached.
5562
/// </summary>
5663
/// <param name="keyIdentifierUri"></param>
5764
internal void AddKey(string keyIdentifierUri)
5865
{
59-
if (TheKeyHasNotBeenCached(keyIdentifierUri))
66+
// Allow only one thread to proceed to ensure thread safety
67+
// as we will need to fetch key information from Azure Key Vault if the key is not found in cache.
68+
_keyDictionarySemaphore.Wait();
69+
70+
try
6071
{
61-
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);
62-
CreateKeyClient(vaultUri);
63-
FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri);
64-
}
72+
if (!_keyDictionary.ContainsKey(keyIdentifierUri))
73+
{
74+
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);
75+
76+
// Fetch the KeyClient for the Key vault URI.
77+
KeyClient keyClient = GetOrCreateKeyClient(vaultUri);
78+
79+
// Fetch the key from Azure Key Vault.
80+
KeyVaultKey key = FetchKeyFromKeyVault(keyClient, keyName, keyVersion);
6581

66-
bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k);
82+
_keyDictionary.AddOrUpdate(keyIdentifierUri, key, (k, v) => key);
83+
}
84+
}
85+
finally
86+
{
87+
_keyDictionarySemaphore.Release();
88+
}
6789
}
6890

6991
/// <summary>
@@ -75,18 +97,12 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
7597
{
7698
if (_keyDictionary.TryGetValue(keyIdentifierUri, out KeyVaultKey key))
7799
{
78-
AKVEventSource.Log.TryTraceEvent("Fetched master key from cache");
100+
AKVEventSource.Log.TryTraceEvent("Fetched key name={0} from cache", key.Name);
79101
return key;
80102
}
81103

82-
if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task<Azure.Response<KeyVaultKey>> task))
83-
{
84-
AKVEventSource.Log.TryTraceEvent("New Master key fetched.");
85-
return Task.Run(() => task).GetAwaiter().GetResult();
86-
}
87-
88104
// Not a public exception - not likely to occur.
89-
AKVEventSource.Log.TryTraceEvent("Master key not found.");
105+
AKVEventSource.Log.TryTraceEvent("Key not found; URI={0}", keyIdentifierUri);
90106
throw ADP.MasterKeyNotFound(keyIdentifierUri);
91107
}
92108

@@ -95,10 +111,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
95111
/// </summary>
96112
/// <param name="keyIdentifierUri">The key vault key identifier URI</param>
97113
/// <returns></returns>
98-
internal int GetKeySize(string keyIdentifierUri)
99-
{
100-
return GetKey(keyIdentifierUri).Key.N.Length;
101-
}
114+
internal int GetKeySize(string keyIdentifierUri) => GetKey(keyIdentifierUri).Key.N.Length;
102115

103116
/// <summary>
104117
/// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
@@ -142,49 +155,58 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)
142155

143156
CryptographyClient cryptographyClient = new(GetKey(keyIdentifierUri).Id, TokenCredential);
144157
_cryptoClientDictionary.TryAdd(keyIdentifierUri, cryptographyClient);
145-
146158
return cryptographyClient;
147159
}
148160

149161
/// <summary>
150-
///
162+
/// Fetches the column encryption key from the Azure Key Vault.
151163
/// </summary>
152-
/// <param name="vaultUri">The Azure Key Vault URI</param>
164+
/// <param name="keyClient">The KeyClient instance</param>
153165
/// <param name="keyName">The name of the Azure Key Vault key</param>
154166
/// <param name="keyVersion">The version of the Azure Key Vault key</param>
155-
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
156-
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri)
167+
private KeyVaultKey FetchKeyFromKeyVault(KeyClient keyClient, string keyName, string keyVersion)
157168
{
158-
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
159-
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask);
169+
AKVEventSource.Log.TryTraceEvent("Fetching key name={0}", keyName);
160170

161-
fetchKeyTask
162-
.ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult()))
163-
.ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult()));
171+
Azure.Response<KeyVaultKey> keyResponse = keyClient?.GetKey(keyName, keyVersion);
164172

165-
Task.Run(() => fetchKeyTask);
173+
// Handle the case where the key response is null or contains an error
174+
// This can happen if the key does not exist or if there is an issue with the KeyClient.
175+
// In such cases, we log the error and throw an exception.
176+
if (keyResponse == null || keyResponse.Value == null || keyResponse.GetRawResponse().IsError)
177+
{
178+
AKVEventSource.Log.TryTraceEvent("Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}", keyName, keyVersion);
179+
if (keyResponse?.GetRawResponse() is Azure.Response response)
180+
{
181+
AKVEventSource.Log.TryTraceEvent("Response status {0} : {1}", response.Status, response.ReasonPhrase);
182+
}
183+
throw ADP.GetKeyFailed(keyName);
184+
}
185+
186+
KeyVaultKey key = keyResponse.Value;
187+
188+
// Validate that the key is of type RSA
189+
key = ValidateRsaKey(key);
190+
return key;
166191
}
167192

168193
/// <summary>
169-
/// Looks up the KeyClient object by it's URI and then fetches the key by name.
194+
/// Gets or creates a KeyClient for the specified Azure Key Vault URI.
170195
/// </summary>
171-
/// <param name="vaultUri">The Azure Key Vault URI</param>
172-
/// <param name="keyName">Then name of the key</param>
173-
/// <param name="keyVersion">Then version of the key</param>
196+
/// <param name="vaultUri">Key Identifier URL</param>
174197
/// <returns></returns>
175-
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
198+
private KeyClient GetOrCreateKeyClient(Uri vaultUri)
176199
{
177-
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient);
178-
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName);
179-
return keyClient?.GetKeyAsync(keyName, keyVersion);
200+
return _keyClientDictionary.GetOrAdd(
201+
vaultUri, (_) => new KeyClient(vaultUri, TokenCredential));
180202
}
181203

182204
/// <summary>
183205
/// Validates that a key is of type RSA
184206
/// </summary>
185207
/// <param name="key"></param>
186208
/// <returns></returns>
187-
private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
209+
private static KeyVaultKey ValidateRsaKey(KeyVaultKey key)
188210
{
189211
if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm)
190212
{
@@ -195,26 +217,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
195217
return key;
196218
}
197219

198-
/// <summary>
199-
/// Instantiates and adds a KeyClient to the KeyClient dictionary
200-
/// </summary>
201-
/// <param name="vaultUri">The Azure Key Vault URI</param>
202-
private void CreateKeyClient(Uri vaultUri)
203-
{
204-
if (!_keyClientDictionary.ContainsKey(vaultUri))
205-
{
206-
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential));
207-
}
208-
}
209-
210220
/// <summary>
211221
/// Validates and parses the Azure Key Vault URI and key name.
212222
/// </summary>
213223
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
214224
/// <param name="vaultUri">The Azure Key Vault URI</param>
215225
/// <param name="masterKeyName">The name of the key</param>
216226
/// <param name="masterKeyVersion">The version of the key</param>
217-
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
227+
private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
218228
{
219229
Uri masterKeyPathUri = new(masterKeyPath);
220230
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority));

src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/LocalCache.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Microsoft.Extensions.Caching.Memory;
65
using System;
6+
using Microsoft.Extensions.Caching.Memory;
77
using static System.Math;
88

99
namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
@@ -92,6 +92,7 @@ internal TValue GetOrCreate(TKey key, Func<TValue> createItem)
9292

9393
/// <summary>
9494
/// Determines whether the <see cref="LocalCache{TKey, TValue}">LocalCache</see> contains the specified key.
95+
/// Used in unit tests to verify that the cache contains the expected entries.
9596
/// </summary>
9697
/// <param name="key"></param>
9798
/// <returns></returns>

src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Text;
7+
using System.Threading;
78
using Azure.Core;
89
using Azure.Security.KeyVault.Keys.Cryptography;
910
using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator;
@@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt
5556

5657
private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep;
5758

59+
private SemaphoreSlim _cacheSemaphore = new(1, 1);
60+
5861
/// <summary>
5962
/// List of Trusted Endpoints
6063
///
@@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt
6972
/// <summary>
7073
/// A cache for storing the results of signature verification of column master key metadata.
7174
/// </summary>
72-
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
75+
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
7376
new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) };
7477

7578
/// <summary>
@@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey()
230233
// Get ciphertext
231234
byte[] cipherText = new byte[cipherTextLength];
232235
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);
233-
236+
234237
currentIndex += cipherTextLength;
235238

236239
// Get signature
@@ -394,17 +397,10 @@ private byte[] CompileMasterKeyMetadata(string masterKeyPath, bool allowEnclaveC
394397
/// <param name="source">An array of bytes to convert.</param>
395398
/// <returns>A string of hexadecimal characters</returns>
396399
/// <remarks>
397-
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in value; for example, "0x7F2C4A00".
400+
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in source; for example, "0x7F2C4A00".
398401
/// </remarks>
399402
private string ToHexString(byte[] source)
400-
{
401-
if (source is null)
402-
{
403-
return null;
404-
}
405-
406-
return "0x" + BitConverter.ToString(source).Replace("-", "");
407-
}
403+
=> source is null ? null : "0x" + BitConverter.ToString(source).Replace("-", "");
408404

409405
/// <summary>
410406
/// Returns the cached decrypted column encryption key, or unwraps the encrypted column encryption key if not present.
@@ -415,8 +411,21 @@ private string ToHexString(byte[] source)
415411
/// <remarks>
416412
///
417413
/// </remarks>
418-
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
419-
=> _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
414+
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
415+
{
416+
// Allow only one thread to access the cache at a time.
417+
_cacheSemaphore.Wait();
418+
419+
try
420+
{
421+
return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
422+
}
423+
finally
424+
{
425+
// Release the semaphore to allow other threads to access the cache.
426+
_cacheSemaphore.Release();
427+
}
428+
}
420429

421430
/// <summary>
422431
/// Returns the cached signature verification result, or proceeds to verify if not present.

src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.Designer.cs

Lines changed: 25 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.resx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,16 @@
118118
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
119119
</resheader>
120120
<data name="NullOrWhitespaceForEach" xml:space="preserve">
121-
<value>One or more of the elements in {0} are null or empty or consist of only whitespace.</value>
121+
<value>One or more of the elements in '{0}' are null or empty or consist of only whitespace.</value>
122122
</data>
123123
<data name="CipherTextLengthMismatch" xml:space="preserve">
124124
<value>CipherText length does not match the RSA key size.</value>
125125
</data>
126126
<data name="EmptyArgumentInternal" xml:space="preserve">
127-
<value>Internal error. Empty {0} specified.</value>
127+
<value>Internal error. Empty '{0}' specified.</value>
128+
</data>
129+
<data name="GetKeyFailed" xml:space="preserve">
130+
<value>Failed to fetch key from Azure Key Vault. Key: {0}.</value>
128131
</data>
129132
<data name="MasterKeyNotFound" xml:space="preserve">
130133
<value>The key with identifier '{0}' was not found.</value>

0 commit comments

Comments
 (0)