Skip to content

Commit 4f6c051

Browse files
committed
Expose SqlAuthenticationParameters on SSPIContextProvider
This change updates the SSPI context provider to surface information to implementers via SqlAuthenticationParameters. As part of this change,the internal storage of SPN is changed from byte[] to string values. Majority of implementations need the string value anyway so it makes things simpler for book keeping.
1 parent e52f1c3 commit 4f6c051

16 files changed

+291
-116
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs

Lines changed: 82 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers;
7+
using System.Diagnostics;
68
using System.Runtime.InteropServices;
79
using System.Text;
810
using Microsoft.Data.Common;
@@ -398,7 +400,7 @@ internal static unsafe uint SNIOpenSyncEx(
398400
ConsumerInfo consumerInfo,
399401
string constring,
400402
ref IntPtr pConn,
401-
byte[] spnBuffer,
403+
ref string spn,
402404
byte[] instanceName,
403405
bool fOverrideCache,
404406
bool fSync,
@@ -436,13 +438,60 @@ internal static unsafe uint SNIOpenSyncEx(
436438
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
437439
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;
438440

439-
if (spnBuffer != null)
441+
if (spn != null)
440442
{
441-
fixed (byte* pin_spnBuffer = &spnBuffer[0])
443+
// An empty string implies we need to find the SPN so we supply a buffer for the max size
444+
if (spn.Length == 0)
442445
{
443-
clientConsumerInfo.szSPN = pin_spnBuffer;
444-
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
445-
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
446+
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
447+
array.AsSpan().Clear();
448+
449+
try
450+
{
451+
fixed (byte* pin_spnBuffer = array)
452+
{
453+
clientConsumerInfo.szSPN = pin_spnBuffer;
454+
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;
455+
456+
var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
457+
458+
if (result == 0)
459+
{
460+
spn = Encoding.Unicode.GetString(array);
461+
}
462+
463+
return result;
464+
}
465+
}
466+
finally
467+
{
468+
ArrayPool<byte>.Shared.Return(array);
469+
}
470+
}
471+
472+
// We have a value of the SPN, so we marshal that and send it to the native layer
473+
else
474+
{
475+
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
476+
var writer = SqlObjectPools.BufferWriter.Rent();
477+
478+
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
479+
Encoding.Unicode.GetBytes(spn, writer);
480+
Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
481+
482+
try
483+
{
484+
fixed (byte* pin_spnBuffer = writer.WrittenSpan)
485+
{
486+
clientConsumerInfo.szSPN = pin_spnBuffer;
487+
clientConsumerInfo.cchSPN = (uint)writer.WrittenCount;
488+
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
489+
}
490+
}
491+
finally
492+
{
493+
SqlObjectPools.BufferWriter.Return(writer);
494+
}
446495
}
447496
}
448497
else
@@ -471,26 +520,37 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
471520
}
472521
}
473522

474-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
523+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, string serverUserName)
475524
{
476525
sendLength = (uint)outBuff.Length;
477526

478-
fixed (byte* pin_serverUserName = &serverUserName[0])
479-
fixed (byte* pInBuff = inBuff)
480-
fixed (byte* pOutBuff = outBuff)
527+
var serverWriter = SqlObjectPools.BufferWriter.Rent();
528+
529+
try
530+
{
531+
Encoding.Unicode.GetBytes(serverUserName, serverWriter);
532+
533+
fixed (byte* pin_serverUserName = serverWriter.WrittenSpan)
534+
fixed (byte* pInBuff = inBuff)
535+
fixed (byte* pOutBuff = outBuff)
536+
{
537+
bool local_fDone;
538+
return SNISecGenClientContextWrapper(
539+
pConnectionObject,
540+
pInBuff,
541+
(uint)inBuff.Length,
542+
pOutBuff,
543+
ref sendLength,
544+
out local_fDone,
545+
pin_serverUserName,
546+
(uint)serverWriter.WrittenCount,
547+
null,
548+
null);
549+
}
550+
}
551+
finally
481552
{
482-
bool local_fDone;
483-
return SNISecGenClientContextWrapper(
484-
pConnectionObject,
485-
pInBuff,
486-
(uint)inBuff.Length,
487-
pOutBuff,
488-
ref sendLength,
489-
out local_fDone,
490-
pin_serverUserName,
491-
(uint)serverUserName.Length,
492-
null,
493-
null);
553+
SqlObjectPools.BufferWriter.Return(serverWriter);
494554
}
495555
}
496556

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Buffers;
7+
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
@@ -33,9 +34,9 @@ internal class SNIProxy
3334
/// <param name="sspiClientContextStatus">SSPI client context status</param>
3435
/// <param name="receivedBuff">Receive buffer</param>
3536
/// <param name="sendWriter">Writer for send buffer</param>
36-
/// <param name="serverName">Service Principal Name buffer</param>
37+
/// <param name="serverNames">Service Principal Name buffer</param>
3738
/// <returns>SNI error code</returns>
38-
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
39+
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
3940
{
4041
// TODO: this should use ReadOnlyMemory all the way through
4142
byte[] array = null;
@@ -46,10 +47,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
4647
receivedBuff.CopyTo(array);
4748
}
4849

49-
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
50+
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverNames);
5051
}
5152

52-
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
53+
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, string[] serverSPNs)
5354
{
5455
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
5556
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -81,11 +82,6 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
8182
| ContextFlagsPal.Delegate
8283
| ContextFlagsPal.MutualAuth;
8384

84-
string[] serverSPNs = new string[serverName.Length];
85-
for (int i = 0; i < serverName.Length; i++)
86-
{
87-
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
88-
}
8985
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
9086
credentialsHandle,
9187
ref securityContext,
@@ -164,7 +160,7 @@ internal static SNIHandle CreateConnectionHandle(
164160
string fullServerName,
165161
TimeoutTimer timeout,
166162
out byte[] instanceName,
167-
ref byte[][] spnBuffer,
163+
ref string[] spnBuffer,
168164
string serverSPN,
169165
bool flushCache,
170166
bool async,
@@ -228,12 +224,12 @@ internal static SNIHandle CreateConnectionHandle(
228224
return sniHandle;
229225
}
230226

231-
private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
227+
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
232228
{
233229
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
234230
if (!string.IsNullOrWhiteSpace(serverSPN))
235231
{
236-
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
232+
return new[] { serverSPN };
237233
}
238234

239235
string hostName = dataSource.ServerName;
@@ -251,7 +247,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
251247
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
252248
}
253249

254-
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
250+
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
255251
{
256252
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
257253
IPHostEntry hostEntry = null;
@@ -282,12 +278,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
282278
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
283279
// Set both SPNs with and without Port as Port is optional for default instance
284280
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
285-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
281+
return new[] { serverSpn, serverSpnWithDefaultPort };
286282
}
287283
// else Named Pipes do not need to valid port
288284

289285
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
290-
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
286+
return new[] { serverSpn };
291287
}
292288

293289
/// <summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ internal static void Assert(string message)
134134

135135
private bool _is2022 = false;
136136

137-
private byte[][] _sniSpnBuffer = null;
137+
private string[] _sniSpn = null;
138138

139139
// SqlStatistics
140140
private SqlStatistics _statistics = null;
@@ -404,7 +404,7 @@ internal void Connect(
404404
}
405405
else
406406
{
407-
_sniSpnBuffer = null;
407+
_sniSpn = null;
408408
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
409409
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
410410
}
@@ -416,7 +416,7 @@ internal void Connect(
416416
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
417417
}
418418

419-
_sniSpnBuffer = null;
419+
_sniSpn = null;
420420
_authenticationProvider = null;
421421

422422
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
@@ -455,7 +455,7 @@ internal void Connect(
455455
serverInfo.ExtendedServerName,
456456
timeout,
457457
out instanceName,
458-
ref _sniSpnBuffer,
458+
ref _sniSpn,
459459
false,
460460
true,
461461
fParallel,
@@ -468,7 +468,7 @@ internal void Connect(
468468
hostNameInCertificate,
469469
serverCertificateFilename);
470470

471-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
471+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);
472472

473473
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
474474
{
@@ -554,7 +554,7 @@ internal void Connect(
554554
_physicalStateObj.CreatePhysicalSNIHandle(
555555
serverInfo.ExtendedServerName,
556556
timeout, out instanceName,
557-
ref _sniSpnBuffer,
557+
ref _sniSpn,
558558
true,
559559
true,
560560
fParallel,
@@ -567,7 +567,7 @@ internal void Connect(
567567
hostNameInCertificate,
568568
serverCertificateFilename);
569569

570-
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
570+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);
571571

572572
if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
573573
{
@@ -576,6 +576,8 @@ internal void Connect(
576576
ThrowExceptionAndWarning(_physicalStateObj);
577577
}
578578

579+
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);
580+
579581
uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);
580582

581583
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
@@ -12850,7 +12852,7 @@ internal string TraceString()
1285012852
_fMARS ? bool.TrueString : bool.FalseString,
1285112853
null == _sessionPool ? "(null)" : _sessionPool.TraceString(),
1285212854
_is2005 ? bool.TrueString : bool.FalseString,
12853-
null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
12855+
null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
1285412856
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
1285512857
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
1285612858
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
186186
string serverName,
187187
TimeoutTimer timeout,
188188
out byte[] instanceName,
189-
ref byte[][] spnBuffer,
189+
ref string[] spn,
190190
bool flushCache,
191191
bool async,
192192
bool fParallel,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
8181
string serverName,
8282
TimeoutTimer timeout,
8383
out byte[] instanceName,
84-
ref byte[][] spnBuffer,
84+
ref string[] spn,
8585
bool flushCache,
8686
bool async,
8787
bool parallel,
@@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
9494
string hostNameInCertificate,
9595
string serverCertificateFilename)
9696
{
97-
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
97+
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
9898
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
9999
hostNameInCertificate, serverCertificateFilename);
100100

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ internal override void CreatePhysicalSNIHandle(
143143
string serverName,
144144
TimeoutTimer timeout,
145145
out byte[] instanceName,
146-
ref byte[][] spnBuffer,
146+
ref string[] spn,
147147
bool flushCache,
148148
bool async,
149149
bool fParallel,
@@ -156,31 +156,28 @@ internal override void CreatePhysicalSNIHandle(
156156
string hostNameInCertificate,
157157
string serverCertificateFilename)
158158
{
159-
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
160-
spnBuffer = new byte[1][];
161159
if (isIntegratedSecurity)
162160
{
163161
// now allocate proper length of buffer
164162
if (!string.IsNullOrEmpty(serverSPN))
165163
{
166164
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
167-
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
168-
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
169-
spnBuffer[0] = srvSPN;
170165
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
171166
}
172167
else
173168
{
174-
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
169+
// This will signal to the interop layer that we need to retrieve the SPN
170+
serverSPN = string.Empty;
175171
}
176172
}
177173

178174
SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
179175
SQLDNSInfo cachedDNSInfo;
180176
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
181177

182-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
178+
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
183179
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
180+
spn = new[] { serverSPN.TrimEnd() };
184181
}
185182

186183
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)

0 commit comments

Comments
 (0)