Skip to content

Ensure correct SPN when calling SspiContextProvider #3347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@
</Compile>
<Compile Include="Microsoft\Data\Common\DbConnectionOptions.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ResolvedServerSpn.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIError.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIHandle.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This is used to hold the ServerSpn for a given connection. Most connection types have a single format, although TCP connections may allow
/// with and without a port. Depending on how the SPN is registered on the server, either one may be the correct name.
/// </summary>
/// <see href="https://learn.microsoft.com/sql/database-engine/configure-windows/register-a-service-principal-name-for-kerberos-connections?view=sql-server-ver17#spn-formats"/>
/// <param name="primary"></param>
/// <param name="secondary"></param>
/// <remarks>
/// <para>SQL Server SPN format follows these patterns:</para>
/// <list type="bullet">
/// <item>
/// <term>Default instance, no port (primary):</term>
/// <description>MSSQLSvc/fully-qualified-domain-name</description>
/// </item>
/// <item>
/// <term>Default instance, default port (secondary):</term>
/// <description>MSSQLSvc/fully-qualified-domain-name:1433</description>
/// </item>
/// <item>
/// <term>Named instance or custom port:</term>
/// <description>MSSQLSvc/fully-qualified-domain-name:port_or_instance_name</description>
/// </item>
/// </list>
/// <para>For TCP connections to named instances, the port number is used in SPN.</para>
/// <para>For Named Pipe connections to named instances, the instance name is used in SPN.</para>
/// <para>When hostname resolution fails, the user-provided hostname is used instead of FQDN.</para>
/// <para>For default instances with TCP protocol, both forms (with and without port) may be returned.</para>
/// </remarks>
internal readonly struct ResolvedServerSpn(string primary, string? secondary = null)
{
public string Primary => primary;

public string? Secondary => secondary;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Data.ProviderBase;
Expand Down Expand Up @@ -51,7 +48,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
out string resolvedSpn,
out ResolvedServerSpn resolvedSpn,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -116,12 +113,12 @@ internal static SNIHandle CreateConnectionHandle(
return sniHandle;
}

private static string GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static ResolvedServerSpn GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return serverSPN;
return new(serverSPN);
}

string hostName = dataSource.ServerName;
Expand All @@ -139,7 +136,7 @@ private static string GetSqlServerSPNs(DataSource dataSource, string serverSPN)
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
}

private static string GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static ResolvedServerSpn GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -170,12 +167,12 @@ private static string GetSqlServerSPNs(string hostNameOrAddress, string portOrIn
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return serverSpnWithDefaultPort;
return new(serverSpn, serverSpnWithDefaultPort);
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return serverSpn;
return new(serverSpn);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
out var serverSpn,
out var resolvedServerSpn,
false,
true,
fParallel,
Expand Down Expand Up @@ -540,7 +540,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
out serverSpn,
out resolvedServerSpn,
true,
true,
fParallel,
Expand Down Expand Up @@ -591,9 +591,9 @@ internal void Connect(ServerInfo serverInfo,
}
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Prelogin handshake successful");

if (_authenticationProvider is { } && serverSpn is { })
if (_authenticationProvider is { })
{
_authenticationProvider.Initialize(serverInfo, _physicalStateObj, this, serverSpn);
_authenticationProvider.Initialize(serverInfo, _physicalStateObj, this, resolvedServerSpn.Primary, resolvedServerSpn.Secondary);
}

if (_fMARS && marsCapable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -71,7 +72,7 @@ internal abstract void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
out string resolvedSpn,
out ResolvedServerSpn resolvedSpn,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
out string resolvedSpn,
out ResolvedServerSpn resolvedSpn,
bool flushCache,
bool async,
bool parallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Interop.Windows.Sni;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -144,7 +145,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
out string resolvedSpn,
out ResolvedServerSpn resolvedSpn,
bool flushCache,
bool async,
bool fParallel,
Expand Down Expand Up @@ -178,7 +179,7 @@ internal override void CreatePhysicalSNIHandle(

_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
resolvedSpn = serverSPN.TrimEnd();
resolvedSpn = new(serverSPN.TrimEnd());
}

protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

namespace Microsoft.Data.SqlClient
{
internal sealed class NegotiateSspiContextProvider : SspiContextProvider
internal sealed class NegotiateSspiContextProvider : SspiContextProvider, IDisposable
{
private NegotiateAuthentication? _negotiateAuth;

protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SspiAuthenticationParameters authParams)
{
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;

_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.Resource });
_negotiateAuth = GetNegotiateAuthenticationForParams(authParams);

var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;

Expand All @@ -31,11 +31,29 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
return true;
}

_negotiateAuth.Dispose();
_negotiateAuth = null;

return false;
}

public void Dispose()
{
_negotiateAuth?.Dispose();
}

private NegotiateAuthentication GetNegotiateAuthenticationForParams(SspiAuthenticationParameters authParams)
{
if (_negotiateAuth is { })
{
if (string.Equals(_negotiateAuth.TargetName, authParams.Resource, StringComparison.Ordinal))
{
return _negotiateAuth;
}

// Dispose of it since we're not going to use it now
_negotiateAuth?.Dispose();
}

return _negotiateAuth = new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.Resource });
}
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@ internal abstract class SspiContextProvider
private TdsParser _parser = null!;
private ServerInfo _serverInfo = null!;

private SspiAuthenticationParameters? _authParam;
private SspiAuthenticationParameters? _primaryAuthParams;
private SspiAuthenticationParameters? _secondaryAuthParams;

private protected TdsParserStateObject _physicalStateObj = null!;

#if NET
/// <remarks>
/// <see cref="SNI.ResolvedServerSpn"/> for details as to what <paramref name="primaryServerSpn"/> and <paramref name="secondaryServerSpn"/> means and why there are two.
/// </remarks>
#endif
internal void Initialize(
ServerInfo serverInfo,
TdsParserStateObject physicalStateObj,
TdsParser parser,
string serverSpn
string primaryServerSpn,
string? secondaryServerSpn = null
)
{
_parser = parser;
Expand All @@ -28,14 +35,23 @@ string serverSpn

var options = parser.Connection.ConnectionOptions;

_authParam = new SspiAuthenticationParameters(options.DataSource, serverSpn)
SqlClientEventSource.Log.StateDumpEvent("<SspiContextProvider> Initializing provider {0} with SPN={1} and alternate={2}", GetType().FullName, primaryServerSpn, secondaryServerSpn);

_primaryAuthParams = CreateAuthParams(options, primaryServerSpn);

if (secondaryServerSpn is { })
{
DatabaseName = options.InitialCatalog,
UserId = options.UserID,
Password = options.Password,
};
_secondaryAuthParams = CreateAuthParams(options, secondaryServerSpn);
}

Initialize();

static SspiAuthenticationParameters CreateAuthParams(SqlConnectionString connString, string serverSpn) => new(connString.DataSource, serverSpn)
{
DatabaseName = connString.InitialCatalog,
UserId = connString.UserID,
Password = connString.Password,
};
}

private protected virtual void Initialize()
Expand All @@ -48,11 +64,30 @@ internal void WriteSSPIContext(ReadOnlySpan<byte> receivedBuff, IBufferWriter<by
{
using var _ = TrySNIEventScope.Create(nameof(SspiContextProvider));

if (!(_authParam is { } && RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, _authParam)))
if (_primaryAuthParams is { })
{
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, _primaryAuthParams))
{
return;
}

// remove _primaryAuth from future attempts as it failed
_primaryAuthParams = null;
}

if (_secondaryAuthParams is { })
{
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, _secondaryAuthParams))
{
return;
}

// remove _secondaryAuthParams from future attempts as it failed
_secondaryAuthParams = null;
}

// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
}

private bool RunGenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SspiAuthenticationParameters authParams)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,12 @@ internal void StateDumpEvent<T0, T1>(string message, T0 args0, T1 args1)
{
StateDump(string.Format(message, args0?.ToString() ?? NullStr, args1?.ToString() ?? NullStr));
}

[NonEvent]
internal void StateDumpEvent<T0, T1, T2>(string message, T0 args0, T1 args1, T2 args2)
{
StateDump(string.Format(message, args0?.ToString() ?? NullStr, args1?.ToString() ?? NullStr, args2?.ToString()));
}
#endregion

#region SNI Trace
Expand Down
Loading