Skip to content

Commit 2f6ddef

Browse files
committed
Pass SqlAuthenticationParameters in GenerateSspiClientContext
As part of this change, the SSPIContextProvider base class now iterates through all the server names similar to what NegotiateSSPIContextProvider did.
1 parent 9450fde commit 2f6ddef

File tree

3 files changed

+81
-34
lines changed

3 files changed

+81
-34
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ private void LoadSSPILibrary()
4949
}
5050
}
5151

52-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
52+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
5353
{
5454
#if NETFRAMEWORK
5555
SNIHandle handle = _physicalStateObj.Handle;
@@ -62,9 +62,9 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
6262
var sendLength = s_maxSSPILength;
6363
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);
6464

65-
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0]))
65+
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.ServerName))
6666
{
67-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
67+
return false;
6868
}
6969

7070
if (sendLength > int.MaxValue)
@@ -73,6 +73,8 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
7373
}
7474

7575
outgoingBlobWriter.Advance((int)sendLength);
76+
77+
return true;
7678
}
7779
}
7880
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#if NET
22

33
using System;
4-
using System.Net.Security;
54
using System.Buffers;
5+
using System.Net.Security;
66

77
#nullable enable
88

@@ -12,33 +12,24 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider
1212
{
1313
private NegotiateAuthentication? _negotiateAuth = null;
1414

15-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
15+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
1616
{
1717
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;
1818

19-
for (int i = 0; i < serverSpns.Length; i++)
20-
{
21-
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverSpns[i] });
22-
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
23-
24-
// Log session id, status code and the actual SPN used in the negotiation
25-
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
26-
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
27-
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
28-
{
29-
outgoingBlobWriter.Write(sendBuff);
30-
break; // Successful case, exit the loop with current SPN.
31-
}
32-
else
33-
{
34-
_negotiateAuth = null; // Reset _negotiateAuth to be generated again for next SPN.
35-
}
36-
}
19+
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.ServerName });
20+
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
21+
22+
// Log session id, status code and the actual SPN used in the negotiation
23+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
24+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
3725

38-
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
26+
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
3927
{
40-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
28+
outgoingBlobWriter.Write(sendBuff);
29+
return true; // Successful case, exit the loop with current SPN.
4130
}
31+
32+
return false;
4233
}
4334
}
4435
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,80 @@ private protected virtual void Initialize()
2626
{
2727
}
2828

29-
protected abstract void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns);
29+
protected abstract bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams);
3030

3131
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
32-
=> SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn });
32+
{
33+
if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
34+
{
35+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
36+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
37+
}
38+
}
3339

34-
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string[] serverSpns)
40+
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
3541
{
3642
using (TrySNIEventScope.Create(nameof(SSPIContextProvider)))
3743
{
38-
try
44+
foreach (var serverSpn in serverSpns)
3945
{
40-
GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpns);
41-
}
42-
catch (Exception e)
43-
{
44-
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
46+
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
47+
{
48+
return;
49+
}
4550
}
51+
52+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
53+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
54+
}
55+
}
56+
57+
private bool RunGenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
58+
{
59+
var authParams = CreateSqlAuthParams(_parser.Connection, serverSpn);
60+
61+
try
62+
{
63+
#if NET8_0_OR_GREATER
64+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, SPN={3}", GetType().FullName,
65+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, serverSpn);
66+
#else
67+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName,
68+
nameof(GenerateSspiClientContext), serverSpn);
69+
#endif
70+
71+
return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams);
72+
}
73+
catch (Exception e)
74+
{
75+
//throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
76+
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
77+
return false;
4678
}
4779
}
4880

81+
private static SqlAuthenticationParameters CreateSqlAuthParams(SqlInternalConnectionTds connection, string serverSpn)
82+
{
83+
var auth = new SqlAuthenticationParameters.Builder(
84+
authenticationMethod: connection.ConnectionOptions.Authentication,
85+
resource: null,
86+
authority: null,
87+
serverName: serverSpn,
88+
connection.ConnectionOptions.InitialCatalog);
89+
90+
if (connection.ConnectionOptions.UserID is { } userId)
91+
{
92+
auth.WithUserId(userId);
93+
}
94+
95+
if (connection.ConnectionOptions.Password is { } password)
96+
{
97+
auth.WithPassword(password);
98+
}
99+
100+
return auth;
101+
}
102+
49103
protected void SSPIError(string error, string procedure)
50104
{
51105
Debug.Assert(!ADP.IsEmpty(procedure), "TdsParser.SSPIError called with an empty or null procedure string");

0 commit comments

Comments
 (0)