diff --git a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs new file mode 100644 index 0000000000..5827c7ee1f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs @@ -0,0 +1,355 @@ +using System; +using System.Collections.Generic; +using System.Reflection; + +using Xunit; + +namespace Microsoft.Data.SqlClient.Tests.Common; + +// This class provides read/write access to LocalAppContextSwitches values +// for the duration of a test. It is intended to be constructed at the start +// of a test and disposed at the end. It captures the original values of +// the switches and restores them when disposed. +// +// This follows the RAII pattern to ensure that the switches are always +// restored, which is important for global state like LocalAppContextSwitches. +// +// https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization +// +public sealed class LocalAppContextSwitchesHelper : IDisposable +{ + #region Public Types + + // This enum is used to represent the state of a switch. + // + // It is a copy of the Tristate enum from LocalAppContextSwitches. + // + public enum Tristate : byte + { + NotInitialized = 0, + False = 1, + True = 2 + } + + #endregion + + #region Construction + + // Construct to capture all existing switch values. + // + // Fails the test if any values cannot be captured. + // + public LocalAppContextSwitchesHelper() + { + // Acquire a handle to the LocalAppContextSwitches type. + var assembly = typeof(SqlCommandBuilder).Assembly; + var switchesType = assembly.GetType( + "Microsoft.Data.SqlClient.LocalAppContextSwitches"); + if (switchesType == null) + { + Assert.Fail("Unable to find LocalAppContextSwitches type."); + } + + // A local helper to acquire a handle to a property. + void InitProperty(string name, out PropertyInfo property) + { + var prop = switchesType.GetProperty( + name, BindingFlags.Public | BindingFlags.Static); + if (prop == null) + { + Assert.Fail($"Unable to find {name} property."); + } + property = prop; + } + + // Acquire handles to all of the public properties of + // LocalAppContextSwitches. + InitProperty( + "LegacyRowVersionNullBehavior", + out _legacyRowVersionNullBehaviorProperty); + InitProperty( + "SuppressInsecureTLSWarning", + out _suppressInsecureTLSWarningProperty); + InitProperty( + "MakeReadAsyncBlocking", + out _makeReadAsyncBlockingProperty); + InitProperty( + "UseMinimumLoginTimeout", + out _useMinimumLoginTimeoutProperty); + InitProperty( + "LegacyVarTimeZeroScaleBehaviour", + out _legacyVarTimeZeroScaleBehaviourProperty); + InitProperty( + "UseCompatibilityProcessSni", + out _useCompatProcessSniProperty); + InitProperty( + "UseCompatibilityAsyncBehaviour", + out _useCompatAsyncBehaviourProperty); +#if NETFRAMEWORK + InitProperty( + "DisableTNIRByDefault", + out _disableTNIRByDefaultProperty); +#endif + + // A local helper to capture the original value of a switch. + void InitField(string name, out FieldInfo field, out Tristate value) + { + var fieldInfo = + switchesType.GetField( + name, BindingFlags.NonPublic | BindingFlags.Static); + if (fieldInfo == null) + { + Assert.Fail($"Unable to find {name} field."); + } + field = fieldInfo; + value = GetValue(field); + } + + // Capture the original value of each switch. + InitField( + "s_legacyRowVersionNullBehavior", + out _legacyRowVersionNullBehaviorField, + out _legacyRowVersionNullBehaviorOriginal); + + InitField( + "s_suppressInsecureTLSWarning", + out _suppressInsecureTLSWarningField, + out _suppressInsecureTLSWarningOriginal); + + InitField( + "s_makeReadAsyncBlocking", + out _makeReadAsyncBlockingField, + out _makeReadAsyncBlockingOriginal); + + InitField( + "s_useMinimumLoginTimeout", + out _useMinimumLoginTimeoutField, + out _useMinimumLoginTimeoutOriginal); + + InitField( + "s_legacyVarTimeZeroScaleBehaviour", + out _legacyVarTimeZeroScaleBehaviourField, + out _legacyVarTimeZeroScaleBehaviourOriginal); + + InitField( + "s_useCompatProcessSni", + out _useCompatProcessSniField, + out _useCompatProcessSniOriginal); + + InitField( + "s_useCompatAsyncBehaviour", + out _useCompatAsyncBehaviourField, + out _useCompatAsyncBehaviourOriginal); + +#if NETFRAMEWORK + InitField( + "s_disableTNIRByDefault", + out _disableTNIRByDefaultField, + out _disableTNIRByDefaultOriginal); +#endif + } + + // Disposal restores all original switch values as a best effort. + public void Dispose() + { + List failedFields = new(); + + void RestoreField(FieldInfo field, Tristate value) + { + try + { + SetValue(field, value); + } + catch (Exception) + { + failedFields.Add(field.Name); + } + } + + RestoreField( + _legacyRowVersionNullBehaviorField, + _legacyRowVersionNullBehaviorOriginal); + RestoreField( + _suppressInsecureTLSWarningField, + _suppressInsecureTLSWarningOriginal); + RestoreField( + _makeReadAsyncBlockingField, + _makeReadAsyncBlockingOriginal); + RestoreField( + _useMinimumLoginTimeoutField, + _useMinimumLoginTimeoutOriginal); + RestoreField( + _legacyVarTimeZeroScaleBehaviourField, + _legacyVarTimeZeroScaleBehaviourOriginal); + RestoreField( + _useCompatProcessSniField, + _useCompatProcessSniOriginal); + RestoreField( + _useCompatAsyncBehaviourField, + _useCompatAsyncBehaviourOriginal); +#if NETFRAMEWORK + RestoreField( + _disableTNIRByDefaultField, + _disableTNIRByDefaultOriginal); +#endif + if (failedFields.Count > 0) + { + Assert.Fail( + $"Failed to restore the following fields: " + + string.Join(", ", failedFields)); + } + } + + #endregion + + #region Public Properties + + // These properties expose the like-named LocalAppContextSwitches + // properties. + public bool LegacyRowVersionNullBehavior + { + get => (bool)_legacyRowVersionNullBehaviorProperty.GetValue(null); + } + public bool SuppressInsecureTLSWarning + { + get => (bool)_suppressInsecureTLSWarningProperty.GetValue(null); + } + public bool MakeReadAsyncBlocking + { + get => (bool)_makeReadAsyncBlockingProperty.GetValue(null); + } + public bool UseMinimumLoginTimeout + { + get => (bool)_useMinimumLoginTimeoutProperty.GetValue(null); + } + public bool LegacyVarTimeZeroScaleBehaviour + { + get => (bool)_legacyVarTimeZeroScaleBehaviourProperty.GetValue(null); + } + public bool UseCompatibilityProcessSni + { + get => (bool)_useCompatProcessSniProperty.GetValue(null); + } + public bool UseCompatibilityAsyncBehaviour + { + get => (bool)_useCompatAsyncBehaviourProperty.GetValue(null); + } +#if NETFRAMEWORK + public bool DisableTNIRByDefault + { + get => (bool)_disableTNIRByDefaultProperty.GetValue(null); + } +#endif + + // These properties get or set the like-named underlying switch field value. + // + // They all fail the test if the value cannot be retrieved or set. + + public Tristate LegacyRowVersionNullBehaviorField + { + get => GetValue(_legacyRowVersionNullBehaviorField); + set => SetValue(_legacyRowVersionNullBehaviorField, value); + } + + public Tristate SuppressInsecureTLSWarningField + { + get => GetValue(_suppressInsecureTLSWarningField); + set => SetValue(_suppressInsecureTLSWarningField, value); + } + + public Tristate MakeReadAsyncBlockingField + { + get => GetValue(_makeReadAsyncBlockingField); + set => SetValue(_makeReadAsyncBlockingField, value); + } + + public Tristate UseMinimumLoginTimeoutField + { + get => GetValue(_useMinimumLoginTimeoutField); + set => SetValue(_useMinimumLoginTimeoutField, value); + } + + public Tristate LegacyVarTimeZeroScaleBehaviourField + { + get => GetValue(_legacyVarTimeZeroScaleBehaviourField); + set => SetValue(_legacyVarTimeZeroScaleBehaviourField, value); + } + + public Tristate UseCompatProcessSniField + { + get => GetValue(_useCompatProcessSniField); + set => SetValue(_useCompatProcessSniField, value); + } + + public Tristate UseCompatAsyncBehaviourField + { + get => GetValue(_useCompatAsyncBehaviourField); + set => SetValue(_useCompatAsyncBehaviourField, value); + } + +#if NETFRAMEWORK + public Tristate DisableTNIRByDefaultField + { + get => GetValue(_disableTNIRByDefaultField); + set => SetValue(_disableTNIRByDefaultField, value); + } +#endif + + #endregion + + #region Private Helpers + + private static Tristate GetValue(FieldInfo field) + { + var value = field.GetValue(null); + if (value is null) + { + Assert.Fail($"Field {field.Name} has a null value."); + } + + return (Tristate)value; + } + + private static void SetValue(FieldInfo field, Tristate value) + { + field.SetValue(null, (byte)value); + } + + #endregion + + #region Private Members + + // These fields are used to expose LocalAppContextSwitches's properties. + private readonly PropertyInfo _legacyRowVersionNullBehaviorProperty; + private readonly PropertyInfo _suppressInsecureTLSWarningProperty; + private readonly PropertyInfo _makeReadAsyncBlockingProperty; + private readonly PropertyInfo _useMinimumLoginTimeoutProperty; + private readonly PropertyInfo _legacyVarTimeZeroScaleBehaviourProperty; + private readonly PropertyInfo _useCompatProcessSniProperty; + private readonly PropertyInfo _useCompatAsyncBehaviourProperty; +#if NETFRAMEWORK + private readonly PropertyInfo _disableTNIRByDefaultProperty; +#endif + + // These fields are used to capture the original switch values. + private readonly FieldInfo _legacyRowVersionNullBehaviorField; + private readonly Tristate _legacyRowVersionNullBehaviorOriginal; + private readonly FieldInfo _suppressInsecureTLSWarningField; + private readonly Tristate _suppressInsecureTLSWarningOriginal; + private readonly FieldInfo _makeReadAsyncBlockingField; + private readonly Tristate _makeReadAsyncBlockingOriginal; + private readonly FieldInfo _useMinimumLoginTimeoutField; + private readonly Tristate _useMinimumLoginTimeoutOriginal; + private readonly FieldInfo _legacyVarTimeZeroScaleBehaviourField; + private readonly Tristate _legacyVarTimeZeroScaleBehaviourOriginal; + private readonly FieldInfo _useCompatProcessSniField; + private readonly Tristate _useCompatProcessSniOriginal; + private readonly FieldInfo _useCompatAsyncBehaviourField; + private readonly Tristate _useCompatAsyncBehaviourOriginal; +#if NETFRAMEWORK + private readonly FieldInfo _disableTNIRByDefaultField; + private readonly Tristate _disableTNIRByDefaultOriginal; +#endif + + #endregion +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs index 99f68c8073..b415e5aba2 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs @@ -11,13 +11,17 @@ namespace Microsoft.Data.SqlClient.Tests public class LocalAppContextSwitchesTests { [Theory] - [InlineData("SuppressInsecureTLSWarning", false)] [InlineData("LegacyRowVersionNullBehavior", false)] + [InlineData("SuppressInsecureTLSWarning", false)] [InlineData("MakeReadAsyncBlocking", false)] [InlineData("UseMinimumLoginTimeout", true)] + [InlineData("LegacyVarTimeZeroScaleBehaviour", true)] [InlineData("UseCompatibilityProcessSni", false)] [InlineData("UseCompatibilityAsyncBehaviour", false)] [InlineData("UseConnectionPoolV2", false)] + #if NETFRAMEWORK + [InlineData("DisableTNIRByDefault", false)] + #endif public void DefaultSwitchValue(string property, bool expectedDefaultValue) { var switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index 646535116d..c5ee1c76a3 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -10,6 +10,9 @@ $(BinFolder)$(Configuration).$(Platform).$(AssemblyName) true + + + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index 12baf6f2e9..288586fb17 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -219,7 +219,7 @@ public static void BetweenAsyncAttentionPacket() var attentionPacket = CreatePacket(13, 6); var input = new List { normalPacket, attentionPacket }; - var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync: true); + using var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync: true); for (int index = 0; index < input.Count; index++) { @@ -248,7 +248,7 @@ public static void MultipleFullPacketsInRemainderAreSplitCorrectly() List input = SplitPacket(CombinePackets(expected), 700); - var stateObject = new TdsParserStateObject(input, dataSize, isAsync: false); + using var stateObject = new TdsParserStateObject(input, dataSize, isAsync: false); var output = MultiplexPacketList(false, dataSize, input); @@ -258,7 +258,7 @@ public static void MultipleFullPacketsInRemainderAreSplitCorrectly() [ExcludeFromCodeCoverage] private static List MultiplexPacketList(bool isAsync, int dataSize, List input) { - var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync); + using var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync); var output = new List(); for (int index = 0; index < input.Count; index++) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs index 94e285b596..47c2db8f47 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs @@ -10,6 +10,8 @@ using System.Reflection; using Xunit; +using SwitchesHelper = Microsoft.Data.SqlClient.Tests.Common.LocalAppContextSwitchesHelper; + namespace Microsoft.Data.SqlClient.Tests { public class SqlParameterTests @@ -1945,89 +1947,29 @@ private enum Int64Enum : long [InlineData(5, 5, false)] [InlineData(6, 6, false)] [InlineData(7, 7, false)] - [InlineData(null, 7, null)] - [InlineData(0, 7, null)] - [InlineData(1, 1, null)] - [InlineData(2, 2, null)] - [InlineData(3, 3, null)] - [InlineData(4, 4, null)] - [InlineData(5, 5, null)] - [InlineData(6, 6, null)] - [InlineData(7, 7, null)] - public void SqlDatetime2Scale_Legacy(int? setScale, byte outputScale, bool? legacyVarTimeZeroScaleSwitchValue) + public void SqlDatetime2Scale_Legacy(int? setScale, byte outputScale, bool legacyVarTimeZeroScaleSwitchValue) { lock (_parameterLegacyScaleLock) { - var originalLegacyVarTimeZeroScaleSwitchValue = SetLegacyVarTimeZeroScaleBehaviour(legacyVarTimeZeroScaleSwitchValue); - try - { - var parameter = new SqlParameter - { - DbType = DbType.DateTime2 - }; - if (setScale.HasValue) - { - parameter.Scale = (byte)setScale.Value; - } + using SwitchesHelper switches = new SwitchesHelper(); + switches.LegacyVarTimeZeroScaleBehaviourField = + legacyVarTimeZeroScaleSwitchValue + ? SwitchesHelper.Tristate.True + : SwitchesHelper.Tristate.False; - var actualScale = (byte)typeof(SqlParameter).GetMethod("GetActualScale", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(parameter, null); - - Assert.Equal(outputScale, actualScale); - } - - finally + var parameter = new SqlParameter + { + DbType = DbType.DateTime2 + }; + if (setScale.HasValue) { - SetLegacyVarTimeZeroScaleBehaviour(originalLegacyVarTimeZeroScaleSwitchValue); + parameter.Scale = (byte)setScale.Value; } - } - } - - [Fact] - public void SetLegacyVarTimeZeroScaleBehaviour_Defaults_to_True() - { - var legacyVarTimeZeroScaleBehaviour = (bool)LocalAppContextSwitchesType.GetProperty("LegacyVarTimeZeroScaleBehaviour", BindingFlags.Public | BindingFlags.Static).GetValue(null); - - Assert.True(legacyVarTimeZeroScaleBehaviour); - } - private static Type LocalAppContextSwitchesType => typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); + var actualScale = (byte)typeof(SqlParameter).GetMethod("GetActualScale", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(parameter, null); - private static bool? SetLegacyVarTimeZeroScaleBehaviour(bool? value) - { - const string LegacyVarTimeZeroScaleBehaviourSwitchname = @"Switch.Microsoft.Data.SqlClient.LegacyVarTimeZeroScaleBehaviour"; - - //reset internal state to "NotInitialized" so we pick up the value via AppContext - FieldInfo switchField = LocalAppContextSwitchesType.GetField("s_legacyVarTimeZeroScaleBehaviour", BindingFlags.NonPublic | BindingFlags.Static); - switchField.SetValue(null, (byte)0); - - bool? returnValue = null; - if (AppContext.TryGetSwitch(LegacyVarTimeZeroScaleBehaviourSwitchname, out var originalValue)) - { - returnValue = originalValue; - } - - if (value.HasValue) - { - AppContext.SetSwitch(LegacyVarTimeZeroScaleBehaviourSwitchname, value.Value); + Assert.Equal(outputScale, actualScale); } - else - { - //need to remove the switch value via reflection as AppContext does not expose a means to do that. -#if NET - var switches = typeof(AppContext).GetField("s_switches", BindingFlags.NonPublic | BindingFlags.Static).GetValue(null); - if (switches is not null) //may be null if not initialised yet - { - MethodInfo removeMethod = switches.GetType().GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, new Type[] { typeof(string) }); - removeMethod.Invoke(switches, new[] { LegacyVarTimeZeroScaleBehaviourSwitchname }); - } -#else - var switches = typeof(AppContext).GetField("s_switchMap", BindingFlags.NonPublic | BindingFlags.Static).GetValue(null); - MethodInfo removeMethod = switches.GetType().GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance); - removeMethod.Invoke(switches, new[] { LegacyVarTimeZeroScaleBehaviourSwitchname }); -#endif - } - - return returnValue; } } } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs index e448b66b83..1d9ebb315d 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -8,13 +8,14 @@ using System.Reflection; using Microsoft.Data.SqlClient.Tests; +using SwitchesHelper = Microsoft.Data.SqlClient.Tests.Common.LocalAppContextSwitchesHelper; + namespace Microsoft.Data.SqlClient { internal struct PacketHandle { } - - internal partial class TdsParserStateObject + internal partial class TdsParserStateObject : IDisposable { internal int ObjectID = 1; @@ -103,6 +104,7 @@ internal void MoveNext() public int _inBytesRead; public int _inBytesUsed; public byte[] _inBuff; + [DebuggerStepThrough] public TdsParserStateObject(List input, int packetSize, bool isAsync) { @@ -114,6 +116,13 @@ public TdsParserStateObject(List input, int packetSize, bool isAsync _snapshot = new Snapshot(); } } + + [DebuggerStepThrough] + public void Dispose() + { + LocalAppContextSwitches.Dispose(); + } + [DebuggerStepThrough] private uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) { @@ -145,30 +154,9 @@ private void AssertValidState() { } [DebuggerStepThrough] private void AddError(object value) => throw new Exception(value as string ?? "AddError"); - internal static class LocalAppContextSwitches - { - public static bool UseCompatibilityProcessSni - { - get - { - var switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - - return (bool)switchesType.GetProperty(nameof(UseCompatibilityProcessSni), BindingFlags.Public | BindingFlags.Static).GetValue(null); - } - } - - public static bool UseCompatibilityAsyncBehaviour - { - get - { - var switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - - return (bool)switchesType.GetProperty(nameof(UseCompatibilityAsyncBehaviour), BindingFlags.Public | BindingFlags.Static).GetValue(null); - } - } - } + private SwitchesHelper LocalAppContextSwitches = new(); - #if NETFRAMEWORK +#if NETFRAMEWORK private SniNativeWrapperImpl _native; internal SniNativeWrapperImpl SniNativeWrapper { @@ -189,7 +177,7 @@ internal class SniNativeWrapperImpl internal uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) => _parent.SniPacketGetData(packet, inBuff, ref dataSize); } - #endif +#endif } internal static class TdsEnums diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 5dfad69698..c930100b95 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -12,6 +12,9 @@ $(BinFolder)$(Configuration).$(Platform).$(AssemblyName) true + + + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs index e64d0fc362..a33ec50593 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs @@ -12,20 +12,14 @@ using System.Threading.Tasks; using Xunit; +using SwitchesHelper = Microsoft.Data.SqlClient.Tests.Common.LocalAppContextSwitchesHelper; + namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public static class DataReaderTest { private static readonly object s_rowVersionLock = new(); - // this enum must mirror the definition in LocalAppContextSwitches - private enum Tristate : byte - { - NotInitialized = 0, - False = 1, - True = 2 - } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void LoadReaderIntoDataTableToTestGetSchemaTable() { @@ -270,34 +264,28 @@ public static void CheckNullRowVersionIsBDNull() { lock (s_rowVersionLock) { - Tristate originalValue = SetLegacyRowVersionNullBehavior(Tristate.False); - try - { - using SqlConnection con = new(DataTestUtility.TCPConnectionString); - con.Open(); - using SqlCommand command = con.CreateCommand(); - command.CommandText = "select cast(null as rowversion) rv"; - using SqlDataReader reader = command.ExecuteReader(); - reader.Read(); - Assert.True(reader.IsDBNull(0)); - Assert.Equal(DBNull.Value, reader[0]); - var result = reader.GetValue(0); - Assert.IsType(result); - Assert.Equal(result, reader.GetFieldValue(0)); - Assert.Throws(() => reader.GetFieldValue(0)); + using SwitchesHelper helper = new(); + helper.LegacyRowVersionNullBehaviorField = SwitchesHelper.Tristate.False; - SqlBinary binary = reader.GetSqlBinary(0); - Assert.True(binary.IsNull); - - SqlBytes bytes = reader.GetSqlBytes(0); - Assert.True(bytes.IsNull); - Assert.Null(bytes.Buffer); - - } - finally - { - SetLegacyRowVersionNullBehavior(originalValue); - } + using SqlConnection con = new(DataTestUtility.TCPConnectionString); + con.Open(); + using SqlCommand command = con.CreateCommand(); + command.CommandText = "select cast(null as rowversion) rv"; + using SqlDataReader reader = command.ExecuteReader(); + reader.Read(); + Assert.True(reader.IsDBNull(0)); + Assert.Equal(DBNull.Value, reader[0]); + var result = reader.GetValue(0); + Assert.IsType(result); + Assert.Equal(result, reader.GetFieldValue(0)); + Assert.Throws(() => reader.GetFieldValue(0)); + + SqlBinary binary = reader.GetSqlBinary(0); + Assert.True(binary.IsNull); + + SqlBytes bytes = reader.GetSqlBytes(0); + Assert.True(bytes.IsNull); + Assert.Null(bytes.Buffer); } } @@ -609,38 +597,24 @@ public static void CheckLegacyNullRowVersionIsEmptyArray() { lock (s_rowVersionLock) { - Tristate originalValue = SetLegacyRowVersionNullBehavior(Tristate.True); - try - { - using SqlConnection con = new(DataTestUtility.TCPConnectionString); - con.Open(); - using SqlCommand command = con.CreateCommand(); - command.CommandText = "select cast(null as rowversion) rv"; - using SqlDataReader reader = command.ExecuteReader(); - reader.Read(); - Assert.False(reader.IsDBNull(0)); - SqlBinary value = reader.GetSqlBinary(0); - Assert.False(value.IsNull); - Assert.Equal(0, value.Length); - Assert.NotNull(value.Value); - var result = reader.GetValue(0); - Assert.IsType(result); - Assert.Equal(result, reader.GetFieldValue(0)); - } - finally - { - SetLegacyRowVersionNullBehavior(originalValue); - } - } - } + using SwitchesHelper helper = new(); + helper.LegacyRowVersionNullBehaviorField = SwitchesHelper.Tristate.True; - private static Tristate SetLegacyRowVersionNullBehavior(Tristate value) - { - Type switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - FieldInfo switchField = switchesType.GetField("s_legacyRowVersionNullBehavior", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); - Tristate originalValue = (Tristate)switchField.GetValue(null); - switchField.SetValue(null, value); - return originalValue; + using SqlConnection con = new(DataTestUtility.TCPConnectionString); + con.Open(); + using SqlCommand command = con.CreateCommand(); + command.CommandText = "select cast(null as rowversion) rv"; + using SqlDataReader reader = command.ExecuteReader(); + reader.Read(); + Assert.False(reader.IsDBNull(0)); + SqlBinary value = reader.GetSqlBinary(0); + Assert.False(value.IsNull); + Assert.Equal(0, value.Length); + Assert.NotNull(value.Value); + var result = reader.GetValue(0); + Assert.IsType(result); + Assert.Equal(result, reader.GetFieldValue(0)); + } } } }