Skip to content

Commit 4c3c3f2

Browse files
committed
Merge remote-tracking branch 'upstream/main' into sspi-writer
2 parents 7c3335e + a34ec48 commit 4c3c3f2

File tree

13 files changed

+106
-65
lines changed

13 files changed

+106
-65
lines changed

src/Microsoft.Data.SqlClient.sln

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ EndProject
196196
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{4F3CD363-B1E6-4D6D-9466-97D78A56BE45}"
197197
ProjectSection(SolutionItems) = preProject
198198
Directory.Build.props = Directory.Build.props
199-
NuGet.config = NuGet.config
200199
EndProjectSection
201200
EndProject
202201
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlServer.Server", "Microsoft.SqlServer.Server\Microsoft.SqlServer.Server.csproj", "{A314812A-7820-4565-A2A8-ABBE391C11E4}"

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
<Product>Core $(BaseProduct)</Product>
1919
<EnableTrimAnalyzer Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0'))">true</EnableTrimAnalyzer>
2020
<NoWarn>$(NoWarn);IL2026;IL2057;IL2072;IL2075</NoWarn>
21+
<RootNamespace />
2122
</PropertyGroup>
2223
<PropertyGroup>
2324
<TargetFrameworkMonikerAssemblyAttributesPath>$([System.IO.Path]::Combine('$(IntermediateOutputPath)','$(TargetFrameworkMoniker).AssemblyAttributes$(DefaultLanguageSourceExtension)'))</TargetFrameworkMonikerAssemblyAttributesPath>
@@ -935,6 +936,7 @@
935936
</EmbeddedResource>
936937
<EmbeddedResource Include="$(CommonSourceRoot)Resources\$(ResxFileName).*.resx">
937938
<Link>Resources\%(RecursiveDir)%(Filename)%(Extension)</Link>
939+
<LogicalName>Microsoft.Data.SqlClient.Resources.%(Filename).resources</LogicalName>
938940
</EmbeddedResource>
939941
<EmbeddedResource Include="Resources\Microsoft.Data.SqlClient.SqlMetaData.xml">
940942
<LogicalName>Microsoft.Data.SqlClient.SqlMetaData.xml</LogicalName>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ internal static SNIHandle CreateConnectionHandle(
196196
}
197197

198198
SNIHandle sniHandle = null;
199-
switch (details._connectionProtocol)
199+
switch (details.ResolvedProtocol)
200200
{
201201
case DataSource.Protocol.Admin:
202202
case DataSource.Protocol.None: // default to using tcp if no protocol is provided
@@ -208,7 +208,7 @@ internal static SNIHandle CreateConnectionHandle(
208208
sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst, hostNameInCertificate, serverCertificateFilename);
209209
break;
210210
default:
211-
Debug.Fail($"Unexpected connection protocol: {details._connectionProtocol}");
211+
Debug.Fail($"Unexpected connection protocol: {details.ResolvedProtocol}");
212212
break;
213213
}
214214

@@ -244,11 +244,11 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
244244
}
245245
else if (!string.IsNullOrWhiteSpace(dataSource.InstanceName))
246246
{
247-
postfix = dataSource._connectionProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName;
247+
postfix = dataSource.ResolvedProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName;
248248
}
249249

250250
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerName {0}, InstanceName {1}, Port {2}, postfix {3}", dataSource?.ServerName, dataSource?.InstanceName, dataSource?.Port, postfix);
251-
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
251+
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
252252
}
253253

254254
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
@@ -326,7 +326,7 @@ private static SNITCPHandle CreateTcpHandle(
326326
}
327327

328328
int port = -1;
329-
bool isAdminConnection = details._connectionProtocol == DataSource.Protocol.Admin;
329+
bool isAdminConnection = details.ResolvedProtocol == DataSource.Protocol.Admin;
330330
if (details.IsSsrpRequired)
331331
{
332332
try
@@ -439,8 +439,6 @@ internal class DataSource
439439

440440
internal enum Protocol { TCP, NP, None, Admin };
441441

442-
internal Protocol _connectionProtocol = Protocol.None;
443-
444442
/// <summary>
445443
/// Provides the HostName of the server to connect to for TCP protocol.
446444
/// This information is also used for finding the SPN of SqlServer
@@ -472,6 +470,12 @@ internal enum Protocol { TCP, NP, None, Admin };
472470
/// </summary>
473471
internal string PipeHostName { get; private set; }
474472

473+
/// <summary>
474+
/// Gets or sets the protocol that was resolved from the connection string. If this is
475+
/// <see cref="Protocol.None"/>, the protocol could not reliably be determined.
476+
/// </summary>
477+
internal Protocol ResolvedProtocol { get; private set; }
478+
475479
private string _workingDataSource;
476480
private string _dataSourceAfterTrimmingProtocol;
477481

@@ -488,16 +492,16 @@ private DataSource(string dataSource)
488492

489493
PopulateProtocol();
490494

491-
_dataSourceAfterTrimmingProtocol = (firstIndexOfColon > -1) && _connectionProtocol != Protocol.None
495+
_dataSourceAfterTrimmingProtocol = (firstIndexOfColon > -1) && ResolvedProtocol != Protocol.None
492496
? _workingDataSource.Substring(firstIndexOfColon + 1).Trim() : _workingDataSource;
493497

494498
if (_dataSourceAfterTrimmingProtocol.Contains(Slash)) // Pipe paths only allow back slashes
495499
{
496-
if (_connectionProtocol == Protocol.None)
500+
if (ResolvedProtocol == Protocol.None)
497501
ReportSNIError(SNIProviders.INVALID_PROV);
498-
else if (_connectionProtocol == Protocol.NP)
502+
else if (ResolvedProtocol == Protocol.NP)
499503
ReportSNIError(SNIProviders.NP_PROV);
500-
else if (_connectionProtocol == Protocol.TCP)
504+
else if (ResolvedProtocol == Protocol.TCP)
501505
ReportSNIError(SNIProviders.TCP_PROV);
502506
}
503507
}
@@ -508,25 +512,25 @@ private void PopulateProtocol()
508512

509513
if (splitByColon.Length <= 1)
510514
{
511-
_connectionProtocol = Protocol.None;
515+
ResolvedProtocol = Protocol.None;
512516
}
513517
else
514518
{
515519
// We trim before switching because " tcp : server , 1433 " is a valid data source
516520
switch (splitByColon[0].Trim())
517521
{
518522
case TdsEnums.TCP:
519-
_connectionProtocol = Protocol.TCP;
523+
ResolvedProtocol = Protocol.TCP;
520524
break;
521525
case TdsEnums.NP:
522-
_connectionProtocol = Protocol.NP;
526+
ResolvedProtocol = Protocol.NP;
523527
break;
524528
case TdsEnums.ADMIN:
525-
_connectionProtocol = Protocol.Admin;
529+
ResolvedProtocol = Protocol.Admin;
526530
break;
527531
default:
528532
// None of the supported protocols were found. This may be a IPv6 address
529-
_connectionProtocol = Protocol.None;
533+
ResolvedProtocol = Protocol.None;
530534
break;
531535
}
532536
}
@@ -610,7 +614,7 @@ private void InferLocalServerName()
610614
// If Server name is empty or localhost, then use "localhost"
611615
if (string.IsNullOrEmpty(ServerName) || IsLocalHost(ServerName) ||
612616
(Environment.MachineName.Equals(ServerName, StringComparison.CurrentCultureIgnoreCase) &&
613-
_connectionProtocol == Protocol.Admin))
617+
ResolvedProtocol == Protocol.Admin))
614618
{
615619
// For DAC use "localhost" instead of the server name.
616620
ServerName = DefaultHostName;
@@ -642,11 +646,11 @@ private bool InferConnectionDetails()
642646
}
643647

644648
// For Tcp and Only Tcp are parameters allowed.
645-
if (_connectionProtocol == Protocol.None)
649+
if (ResolvedProtocol == Protocol.None)
646650
{
647-
_connectionProtocol = Protocol.TCP;
651+
ResolvedProtocol = Protocol.TCP;
648652
}
649-
else if (_connectionProtocol != Protocol.TCP)
653+
else if (ResolvedProtocol != Protocol.TCP)
650654
{
651655
// Parameter has been specified for non-TCP protocol. This is not allowed.
652656
ReportSNIError(SNIProviders.INVALID_PROV);
@@ -704,15 +708,15 @@ private void ReportSNIError(SNIProviders provider)
704708
private bool InferNamedPipesInformation()
705709
{
706710
// If we have a datasource beginning with a pipe or we have already determined that the protocol is Named Pipe
707-
if (_dataSourceAfterTrimmingProtocol.StartsWith(PipeBeginning, StringComparison.Ordinal) || _connectionProtocol == Protocol.NP)
711+
if (_dataSourceAfterTrimmingProtocol.StartsWith(PipeBeginning, StringComparison.Ordinal) || ResolvedProtocol == Protocol.NP)
708712
{
709713
// If the data source starts with "np:servername"
710714
if (!_dataSourceAfterTrimmingProtocol.Contains(PipeBeginning))
711715
{
712716
// Assuming that user did not change default NamedPipe name, if the datasource is in the format servername\instance,
713717
// separate servername and instance and prepend instance with MSSQL$ and append default pipe path
714718
// https://learn.microsoft.com/en-us/sql/tools/configuration-manager/named-pipes-properties?view=sql-server-ver16
715-
if (_dataSourceAfterTrimmingProtocol.Contains(PathSeparator) && _connectionProtocol == Protocol.NP)
719+
if (_dataSourceAfterTrimmingProtocol.Contains(PathSeparator) && ResolvedProtocol == Protocol.NP)
716720
{
717721
string[] tokensByBackSlash = _dataSourceAfterTrimmingProtocol.Split(BackSlashCharacter);
718722
if (tokensByBackSlash.Length == 2)
@@ -799,11 +803,11 @@ private bool InferNamedPipesInformation()
799803
}
800804

801805
// DataSource is something like "\\pipename"
802-
if (_connectionProtocol == Protocol.None)
806+
if (ResolvedProtocol == Protocol.None)
803807
{
804-
_connectionProtocol = Protocol.NP;
808+
ResolvedProtocol = Protocol.NP;
805809
}
806-
else if (_connectionProtocol != Protocol.NP)
810+
else if (ResolvedProtocol != Protocol.NP)
807811
{
808812
// In case the path began with a "\\" and protocol was not Named Pipes
809813
ReportSNIError(SNIProviders.NP_PROV);

src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.Net.Sdk">
22
<PropertyGroup>
33
<ProjectGuid>{407890AC-9876-4FEF-A6F1-F36A876BAADE}</ProjectGuid>
4-
<RootNamespace>Microsoft.Data.SqlClient</RootNamespace>
4+
<RootNamespace></RootNamespace>
55
<TargetFramework>net462</TargetFramework>
66
<EnableLocalAppContext>true</EnableLocalAppContext>
77
<AssemblyName>Microsoft.Data.SqlClient</AssemblyName>
@@ -731,6 +731,7 @@
731731
</EmbeddedResource>
732732
<EmbeddedResource Include="$(CommonSourceRoot)Resources\$(ResxFileName).*.resx">
733733
<Link>Resources\%(RecursiveDir)%(Filename)%(Extension)</Link>
734+
<LogicalName>Microsoft.Data.SqlClient.Resources.%(Filename).resources</LogicalName>
734735
</EmbeddedResource>
735736
<EmbeddedResource Include="Resources\Microsoft.Data.SqlClient.SqlMetaData.xml">
736737
<LogicalName>Microsoft.Data.SqlClient.SqlMetaData.xml</LogicalName>

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
120120
using CancellationTokenSource cts = new();
121121

122122
// Use Connection timeout value to cancel token acquire request after certain period of time.
123-
cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
123+
int timeout = parameters.ConnectionTimeout * 1000; // Convert to milliseconds
124+
if (timeout > 0) // if ConnectionTimeout is 0 or the millis overflows an int, no need to set CancelAfter
125+
{
126+
cts.CancelAfter(timeout);
127+
}
124128

125129
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix, StringComparison.Ordinal) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
126130
string[] scopes = new string[] { scope };

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public static class DataTestUtility
6767
public static readonly string EnclaveAzureDatabaseConnString = null;
6868
public static bool ManagedIdentitySupported = true;
6969
public static string AADAccessToken = null;
70+
public static bool SupportsSystemAssignedManagedIdentity = false;
7071
public static string AADSystemIdentityAccessToken = null;
7172
public static string AADUserIdentityAccessToken = null;
7273
public const string ApplicationClientId = "2fd908ad-0664-4344-b9be-cd3e8b574c38";
@@ -108,6 +109,15 @@ public static bool IsAzureSynapse
108109
}
109110
}
110111

112+
public static bool TcpConnectionStringDoesNotUseAadAuth
113+
{
114+
get
115+
{
116+
SqlConnectionStringBuilder builder = new (TCPConnectionString);
117+
return builder.Authentication == SqlAuthenticationMethod.SqlPassword || builder.Authentication == SqlAuthenticationMethod.NotSpecified;
118+
}
119+
}
120+
111121
public static string SQLServerVersion
112122
{
113123
get
@@ -645,7 +655,7 @@ public static string GetAccessToken()
645655

646656
public static string GetSystemIdentityAccessToken()
647657
{
648-
if (true == ManagedIdentitySupported && null == AADSystemIdentityAccessToken && IsAADPasswordConnStrSetup())
658+
if (ManagedIdentitySupported && SupportsSystemAssignedManagedIdentity && null == AADSystemIdentityAccessToken && IsAADPasswordConnStrSetup())
649659
{
650660
AADSystemIdentityAccessToken = AADUtility.GetManagedIdentityToken().GetAwaiter().GetResult();
651661
if (AADSystemIdentityAccessToken == null)
@@ -658,7 +668,7 @@ public static string GetSystemIdentityAccessToken()
658668

659669
public static string GetUserIdentityAccessToken()
660670
{
661-
if (true == ManagedIdentitySupported && null == AADUserIdentityAccessToken && IsAADPasswordConnStrSetup())
671+
if (ManagedIdentitySupported && null == AADUserIdentityAccessToken && IsAADPasswordConnStrSetup())
662672
{
663673
// Pass User Assigned Managed Identity Client Id here.
664674
AADUserIdentityAccessToken = AADUtility.GetManagedIdentityToken(UserManagedIdentityClientId).GetAwaiter().GetResult();

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AdapterTest/AdapterTest.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,14 @@ public void UpdateOffsetTest()
10591059
}
10601060
}
10611061

1062-
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
1062+
public static bool CanRunSchemaTests()
1063+
{
1064+
return DataTestUtility.AreConnStringsSetup() &&
1065+
// Tests switch to master database, which is not guaranteed when using AAD auth
1066+
DataTestUtility.TcpConnectionStringDoesNotUseAadAuth;
1067+
}
1068+
1069+
[ConditionalFact(nameof(CanRunSchemaTests))]
10631070
public void SelectAllTest()
10641071
{
10651072
// Test exceptions

0 commit comments

Comments
 (0)