diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 7d94425123..f897e55f77 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -10,6 +10,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 9ce3b09334..a3aa812c00 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -930,7 +930,6 @@ - @@ -971,9 +970,11 @@ + + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 0715bd8205..97d51aec13 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -6416,7 +6416,7 @@ internal bool DeserializeUnencryptedValue(SqlBuffer value, byte[] unencryptedByt case TdsEnums.SQLUNIQUEID: { Debug.Assert(length == 16, "invalid length for SqlGuid type!"); - value.SqlGuid = SqlTypeWorkarounds.SqlGuidCtor(unencryptedBytes, true); // doesn't copy the byte array + value.SqlGuid = SqlTypeWorkarounds.ByteArrayToSqlGuid(unencryptedBytes); break; } @@ -6437,7 +6437,7 @@ internal bool DeserializeUnencryptedValue(SqlBuffer value, byte[] unencryptedByt unencryptedBytes = bytes; } - value.SqlBinary = SqlTypeWorkarounds.SqlBinaryCtor(unencryptedBytes, true); // doesn't copy the byte array + value.SqlBinary = SqlTypeWorkarounds.ByteArrayToSqlBinary(unencryptedBytes); break; } @@ -6662,7 +6662,7 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, } else { - value.SqlBinary = SqlTypeWorkarounds.SqlBinaryCtor(b, true); // doesn't copy the byte array + value.SqlBinary = SqlTypeWorkarounds.ByteArrayToSqlBinary(b); } break; @@ -6677,7 +6677,7 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, // Internally, we use Sqlbinary to deal with varbinary data and store it in // SqlBuffer as SqlBinary value. - value.SqlBinary = SqlTypeWorkarounds.SqlBinaryCtor(b, true); + value.SqlBinary = SqlTypeWorkarounds.ByteArrayToSqlBinary(b); // Extract the metadata from the payload and set it as the vector attributes // in the SqlBuffer. This metadata is further used when constructing a SqlVector @@ -7006,8 +7006,8 @@ internal TdsOperationStatus TryReadSqlValueInternal(SqlBuffer value, byte tdsTyp { return result; } - value.SqlBinary = SqlTypeWorkarounds.SqlBinaryCtor(b, true); // doesn't copy the byte array - + + value.SqlBinary = SqlTypeWorkarounds.ByteArrayToSqlBinary(b); break; } @@ -7887,22 +7887,27 @@ internal byte[] SerializeSqlDecimal(SqlDecimal d, TdsParserStateObject stateObj) // sign if (d.IsPositive) + { bytes[current++] = 1; + } else + { bytes[current++] = 0; + } - uint data1, data2, data3, data4; - SqlTypeWorkarounds.SqlDecimalExtractData(d, out data1, out data2, out data3, out data4); - byte[] bytesPart = SerializeUnsignedInt(data1, stateObj); + Span data = stackalloc uint[4]; + SqlTypeWorkarounds.SqlDecimalWriteTdsValue(d, data); + + byte[] bytesPart = SerializeUnsignedInt(data[0], stateObj); Buffer.BlockCopy(bytesPart, 0, bytes, current, 4); current += 4; - bytesPart = SerializeUnsignedInt(data2, stateObj); + bytesPart = SerializeUnsignedInt(data[1], stateObj); Buffer.BlockCopy(bytesPart, 0, bytes, current, 4); current += 4; - bytesPart = SerializeUnsignedInt(data3, stateObj); + bytesPart = SerializeUnsignedInt(data[2], stateObj); Buffer.BlockCopy(bytesPart, 0, bytes, current, 4); current += 4; - bytesPart = SerializeUnsignedInt(data4, stateObj); + bytesPart = SerializeUnsignedInt(data[3], stateObj); Buffer.BlockCopy(bytesPart, 0, bytes, current, 4); return bytes; @@ -7912,16 +7917,21 @@ internal void WriteSqlDecimal(SqlDecimal d, TdsParserStateObject stateObj) { // sign if (d.IsPositive) + { stateObj.WriteByte(1); + } else + { stateObj.WriteByte(0); + } + + Span data = stackalloc uint[4]; + SqlTypeWorkarounds.SqlDecimalWriteTdsValue(d, data); - uint data1, data2, data3, data4; - SqlTypeWorkarounds.SqlDecimalExtractData(d, out data1, out data2, out data3, out data4); - WriteUnsignedInt(data1, stateObj); - WriteUnsignedInt(data2, stateObj); - WriteUnsignedInt(data3, stateObj); - WriteUnsignedInt(data4, stateObj); + WriteUnsignedInt(data[0], stateObj); + WriteUnsignedInt(data[1], stateObj); + WriteUnsignedInt(data[2], stateObj); + WriteUnsignedInt(data[3], stateObj); } private byte[] SerializeDecimal(decimal value, TdsParserStateObject stateObj) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.netfx.cs deleted file mode 100644 index 941e3325b6..0000000000 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.netfx.cs +++ /dev/null @@ -1,343 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Data.SqlTypes; -using System.Reflection; -using System.Reflection.Emit; -using System.Runtime.Serialization; -using Microsoft.Data.SqlClient; - -namespace Microsoft.Data.SqlTypes -{ - /// - /// This type provides workarounds for the separation between System.Data.Common - /// and Microsoft.Data.SqlClient. The latter wants to access internal members of the former, and - /// this class provides ways to do that. We must review and update this implementation any time the - /// implementation of the corresponding types in System.Data.Common change. - /// - internal static partial class SqlTypeWorkarounds - { - #region Work around inability to access SqlMoney.ctor(long, int) and SqlMoney.ToSqlInternalRepresentation - private static readonly Func s_sqlMoneyfactory = CtorHelper.CreateFactory(); // binds to SqlMoney..ctor(long, int) if it exists - - /// - /// Constructs a SqlMoney from a long value without scaling. The ignored parameter exists - /// only to distinguish this constructor from the constructor that takes a long. - /// Used only internally. - /// - internal static SqlMoney SqlMoneyCtor(long value, int ignored) - { - SqlMoney val; - if (s_sqlMoneyfactory is not null) - { - val = s_sqlMoneyfactory(value); - } - else - { - // SqlMoney is a long internally. Dividing by 10,000 gives us the decimal representation - val = new SqlMoney(((decimal)value) / 10000); - } - - return val; - } - - internal static long SqlMoneyToSqlInternalRepresentation(SqlMoney money) - { - return SqlMoneyHelper.s_sqlMoneyToLong(ref money); - } - - private static class SqlMoneyHelper - { - internal delegate long SqlMoneyToLongDelegate(ref SqlMoney @this); - internal static readonly SqlMoneyToLongDelegate s_sqlMoneyToLong = GetSqlMoneyToLong(); - - internal static SqlMoneyToLongDelegate GetSqlMoneyToLong() - { - SqlMoneyToLongDelegate del = null; - try - { - del = GetFastSqlMoneyToLong(); - } - catch - { - // If an exception occurs for any reason, swallow & use the fallback code path. - } - - return del ?? FallbackSqlMoneyToLong; - } - - private static SqlMoneyToLongDelegate GetFastSqlMoneyToLong() - { - MethodInfo toSqlInternalRepresentation = typeof(SqlMoney).GetMethod("ToSqlInternalRepresentation", - BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.ExactBinding, - null, CallingConventions.Any, new Type[] { }, null); - - if (toSqlInternalRepresentation is not null && toSqlInternalRepresentation.ReturnType == typeof(long)) - { - // On Full Framework, invoking the MethodInfo first before wrapping - // a delegate around it will produce better codegen. We don't need - // to inspect the return value; we just need to call the method. - - _ = toSqlInternalRepresentation.Invoke(new SqlMoney(0), new object[0]); - - // Now create the delegate. This is an open delegate, meaning the - // "this" parameter will be provided as arg0 on each call. - - var del = (SqlMoneyToLongDelegate)toSqlInternalRepresentation.CreateDelegate(typeof(SqlMoneyToLongDelegate), target: null); - - // Now we can cache the delegate and invoke it over and over again. - // Note: the first parameter to the delegate is provided *byref*. - - return del; - } - - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.GetFastSqlMoneyToLong | Info | SqlMoney.ToSqlInternalRepresentation() not found. Less efficient fallback method will be used."); - return null; // missing the expected method - cannot use fast path - } - - // Used in case we can't use a [Serializable]-like mechanism. - private static long FallbackSqlMoneyToLong(ref SqlMoney value) - { - if (value.IsNull) - { - return default; - } - else - { - decimal data = value.ToDecimal(); - return (long)(data * 10000); - } - } - } - #endregion - - #region Work around inability to access SqlDecimal._data1/2/3/4 - internal static void SqlDecimalExtractData(SqlDecimal d, out uint data1, out uint data2, out uint data3, out uint data4) - { - SqlDecimalHelper.s_decompose(d, out data1, out data2, out data3, out data4); - } - - private static class SqlDecimalHelper - { - internal delegate void Decomposer(SqlDecimal value, out uint data1, out uint data2, out uint data3, out uint data4); - internal static readonly Decomposer s_decompose = GetDecomposer(); - - private static Decomposer GetDecomposer() - { - Decomposer decomposer = null; - try - { - decomposer = GetFastDecomposer(); - } - catch - { - // If an exception occurs for any reason, swallow & use the fallback code path. - } - - return decomposer ?? FallbackDecomposer; - } - - private static Decomposer GetFastDecomposer() - { - // This takes advantage of the fact that for [Serializable] types, the member fields are implicitly - // part of the type's serialization contract. This includes the fields' names and types. By default, - // [Serializable]-compliant serializers will read all the member fields and shove the data into a - // SerializationInfo dictionary. We mimic this behavior in a manner consistent with the [Serializable] - // pattern, but much more efficiently. - // - // In order to make sure we're staying compliant, we need to gate our checks to fulfill some core - // assumptions. Importantly, the type must be [Serializable] but cannot be ISerializable, as the - // presence of the interface means that the type wants to be responsible for its own serialization, - // and that member fields are not guaranteed to be part of the serialization contract. Additionally, - // we need to check for [OnSerializing] and [OnDeserializing] methods, because we cannot account - // for any logic which might be present within them. - - if (!typeof(SqlDecimal).IsSerializable) - { - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.SqlDecimalHelper.GetFastDecomposer | Info | SqlDecimal isn't Serializable. Less efficient fallback method will be used."); - return null; // type is not serializable - cannot use fast path assumptions - } - - if (typeof(ISerializable).IsAssignableFrom(typeof(SqlDecimal))) - { - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.SqlDecimalHelper.GetFastDecomposer | Info | SqlDecimal is ISerializable. Less efficient fallback method will be used."); - return null; // type contains custom logic - cannot use fast path assumptions - } - - foreach (MethodInfo method in typeof(SqlDecimal).GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) - { - if (method.IsDefined(typeof(OnDeserializingAttribute)) || method.IsDefined(typeof(OnDeserializedAttribute))) - { - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.SqlDecimalHelper.GetFastDecomposer | Info | SqlDecimal contains custom serialization logic. Less efficient fallback method will be used."); - return null; // type contains custom logic - cannot use fast path assumptions - } - } - - // GetSerializableMembers filters out [NonSerialized] fields for us automatically. - - FieldInfo fiData1 = null, fiData2 = null, fiData3 = null, fiData4 = null; - foreach (MemberInfo candidate in FormatterServices.GetSerializableMembers(typeof(SqlDecimal))) - { - if (candidate is FieldInfo fi && fi.FieldType == typeof(uint)) - { - if (fi.Name == "m_data1") - { fiData1 = fi; } - else if (fi.Name == "m_data2") - { fiData2 = fi; } - else if (fi.Name == "m_data3") - { fiData3 = fi; } - else if (fi.Name == "m_data4") - { fiData4 = fi; } - } - } - - if (fiData1 is null || fiData2 is null || fiData3 is null || fiData4 is null) - { - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.SqlDecimalHelper.GetFastDecomposer | Info | Expected SqlDecimal fields are missing. Less efficient fallback method will be used."); - return null; // missing one of the expected member fields - cannot use fast path assumptions - } - - Type refToUInt32 = typeof(uint).MakeByRefType(); - DynamicMethod dm = new( - name: "sqldecimal-decomposer", - returnType: typeof(void), - parameterTypes: new[] { typeof(SqlDecimal), refToUInt32, refToUInt32, refToUInt32, refToUInt32 }, - restrictedSkipVisibility: true); // perf: JITs method at delegate creation time - - ILGenerator ilGen = dm.GetILGenerator(); - ilGen.Emit(OpCodes.Ldarg_1); // eval stack := [UInt32&] - ilGen.Emit(OpCodes.Ldarg_0); // eval stack := [UInt32&] [SqlDecimal] - ilGen.Emit(OpCodes.Ldfld, fiData1); // eval stack := [UInt32&] [UInt32] - ilGen.Emit(OpCodes.Stind_I4); // eval stack := - ilGen.Emit(OpCodes.Ldarg_2); // eval stack := [UInt32&] - ilGen.Emit(OpCodes.Ldarg_0); // eval stack := [UInt32&] [SqlDecimal] - ilGen.Emit(OpCodes.Ldfld, fiData2); // eval stack := [UInt32&] [UInt32] - ilGen.Emit(OpCodes.Stind_I4); // eval stack := - ilGen.Emit(OpCodes.Ldarg_3); // eval stack := [UInt32&] - ilGen.Emit(OpCodes.Ldarg_0); // eval stack := [UInt32&] [SqlDecimal] - ilGen.Emit(OpCodes.Ldfld, fiData3); // eval stack := [UInt32&] [UInt32] - ilGen.Emit(OpCodes.Stind_I4); // eval stack := - ilGen.Emit(OpCodes.Ldarg_S, (byte)4); // eval stack := [UInt32&] - ilGen.Emit(OpCodes.Ldarg_0); // eval stack := [UInt32&] [SqlDecimal] - ilGen.Emit(OpCodes.Ldfld, fiData4); // eval stack := [UInt32&] [UInt32] - ilGen.Emit(OpCodes.Stind_I4); // eval stack := - ilGen.Emit(OpCodes.Ret); - - return (Decomposer)dm.CreateDelegate(typeof(Decomposer), null /* target */); - } - - // Used in case we can't use a [Serializable]-like mechanism. - private static void FallbackDecomposer(SqlDecimal value, out uint data1, out uint data2, out uint data3, out uint data4) - { - if (value.IsNull) - { - data1 = default; - data2 = default; - data3 = default; - data4 = default; - } - else - { - int[] data = value.Data; // allocation - data4 = (uint)data[3]; // write in reverse to avoid multiple bounds checks - data3 = (uint)data[2]; - data2 = (uint)data[1]; - data1 = (uint)data[0]; - } - } - } - #endregion - - #region Work around inability to access SqlBinary.ctor(byte[], bool) - private static readonly Func s_sqlBinaryfactory = CtorHelper.CreateFactory(); // binds to SqlBinary..ctor(byte[], bool) if it exists - - internal static SqlBinary SqlBinaryCtor(byte[] value, bool ignored) - { - SqlBinary val; - if (s_sqlBinaryfactory is not null) - { - val = s_sqlBinaryfactory(value); - } - else - { - val = new SqlBinary(value); - } - - return val; - } - #endregion - - #region Work around inability to access SqlGuid.ctor(byte[], bool) - private static readonly Func s_sqlGuidfactory = CtorHelper.CreateFactory(); // binds to SqlGuid..ctor(byte[], bool) if it exists - - internal static SqlGuid SqlGuidCtor(byte[] value, bool ignored) - { - SqlGuid val; - if (s_sqlGuidfactory is not null) - { - val = s_sqlGuidfactory(value); - } - else - { - val = new SqlGuid(value); - } - - return val; - } - #endregion - - private static class CtorHelper - { - // Returns null if .ctor(TValue, TIgnored) cannot be found. - // Caller should have fallback logic in place in case the API doesn't exist. - internal unsafe static Func CreateFactory() where TInstance : struct - { - try - { - ConstructorInfo fullCtor = typeof(TInstance).GetConstructor( - BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.ExactBinding, - null, new[] { typeof(TValue), typeof(TIgnored) }, null); - if (fullCtor is not null) - { - // Need to use fnptr rather than delegate since MulticastDelegate expects to point to a MethodInfo, - // not a ConstructorInfo. The convention for invoking struct ctors is that the caller zeros memory, - // then passes a ref to the zeroed memory as the implicit arg0 "this". We don't need to worry - // about keeping this pointer alive; the fact that we're instantiated over TInstance will do it - // for us. - // - // On Full Framework, creating a delegate to InvocationHelper before invoking it for the first time - // will cause the delegate to point to the pre-JIT stub, which has an expensive preamble. Instead, - // we invoke InvocationHelper manually with a captured no-op fnptr. We'll then replace it with the - // real fnptr before creating a new delegate (pointing to the real codegen, not the stub) and - // returning that new delegate to our caller. - - static void DummyNoOp(ref TInstance @this, TValue value, TIgnored ignored) - { } - - IntPtr fnPtr; - TInstance InvocationHelper(TValue value) - { - TInstance retVal = default; // ensure zero-inited - ((delegate* managed)fnPtr)(ref retVal, value, default); - return retVal; - } - - fnPtr = (IntPtr)(delegate* managed)(&DummyNoOp); - InvocationHelper(default); // no-op to trigger JIT - - fnPtr = fullCtor.MethodHandle.GetFunctionPointer(); // replace before returning to caller - return InvocationHelper; - } - } - catch - { - } - - SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.CtorHelper.CreateFactory | Info | {0}..ctor({1}, {2}) not found. Less efficient fallback method will be used.", typeof(TInstance).Name, typeof(TValue).Name, typeof(TIgnored).Name); - return null; // factory not found or an exception occurred - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/src/Microsoft.Data.SqlClient.csproj index ff8274c66f..aef9ceeb7e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/src/Microsoft.Data.SqlClient.csproj @@ -20,6 +20,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/ValueUtilsSmi.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/ValueUtilsSmi.cs index 39b2494fa0..3369ce49fb 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/ValueUtilsSmi.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/ValueUtilsSmi.cs @@ -2965,11 +2965,12 @@ private static SqlMoney GetSqlMoney_Unchecked(ITypedGettersV3 getters, int ordin Debug.Assert(!IsDBNull_Unchecked(getters, ordinal)); long temp = getters.GetInt64(ordinal); -#if NET + + #if NET return SqlMoney.FromTdsValue(temp); -#else - return SqlTypeWorkarounds.SqlMoneyCtor(temp, 1 /* ignored */ ); -#endif + #else + return SqlTypeWorkarounds.LongToSqlMoney(temp); + #endif } private static SqlXml GetSqlXml_Unchecked(ITypedGettersV3 getters, int ordinal) @@ -3395,11 +3396,13 @@ private static void SetSqlMoney_Unchecked(ITypedSettersV3 setters, int ordinal, setters.SetVariantMetaData(ordinal, SmiMetaData.DefaultMoney); } -#if NET - setters.SetInt64(ordinal, value.GetTdsValue()); -#else - setters.SetInt64(ordinal, SqlTypeWorkarounds.SqlMoneyToSqlInternalRepresentation(value)); -#endif + #if NET + long longValue = value.GetTdsValue(); + #else + long longValue = SqlTypeWorkarounds.SqlMoneyToLong(value); + #endif + + setters.SetInt64(ordinal, longValue); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs index 39d2758d62..80fdddad57 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs @@ -918,11 +918,12 @@ internal SqlMoney SqlMoney { return SqlMoney.Null; } -#if NET + + #if NET return SqlMoney.FromTdsValue(_value._int64); -#else - return SqlTypeWorkarounds.SqlMoneyCtor(_value._int64, 1/*ignored*/); -#endif + #else + return SqlTypeWorkarounds.LongToSqlMoney(_value._int64); + #endif } return (SqlMoney)SqlValue; // anything else we haven't thought of goes through boxing. } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs index 853be887dc..1c825bbaef 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs @@ -10,6 +10,11 @@ using System.Xml; using Microsoft.Data.SqlClient; +#if NETFRAMEWORK +using System.Reflection; +using System.Runtime.InteropServices; +#endif + namespace Microsoft.Data.SqlTypes { /// @@ -18,39 +23,35 @@ namespace Microsoft.Data.SqlTypes /// this class provides ways to do that. We must review and update this implementation any time the /// implementation of the corresponding types in System.Data.Common change. /// - internal static partial class SqlTypeWorkarounds + internal static class SqlTypeWorkarounds { #region Work around inability to access SqlXml.CreateSqlXmlReader + private static readonly XmlReaderSettings s_defaultXmlReaderSettings = new() { ConformanceLevel = ConformanceLevel.Fragment }; private static readonly XmlReaderSettings s_defaultXmlReaderSettingsCloseInput = new() { ConformanceLevel = ConformanceLevel.Fragment, CloseInput = true }; private static readonly XmlReaderSettings s_defaultXmlReaderSettingsAsyncCloseInput = new() { Async = true, ConformanceLevel = ConformanceLevel.Fragment, CloseInput = true }; internal const SqlCompareOptions SqlStringValidSqlCompareOptionMask = - SqlCompareOptions.IgnoreCase | SqlCompareOptions.IgnoreWidth | - SqlCompareOptions.IgnoreNonSpace | SqlCompareOptions.IgnoreKanaType | - SqlCompareOptions.BinarySort | SqlCompareOptions.BinarySort2; + SqlCompareOptions.BinarySort | + SqlCompareOptions.BinarySort2 | + SqlCompareOptions.IgnoreCase | + SqlCompareOptions.IgnoreWidth | + SqlCompareOptions.IgnoreNonSpace | + SqlCompareOptions.IgnoreKanaType; internal static XmlReader SqlXmlCreateSqlXmlReader(Stream stream, bool closeInput, bool async) { Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async"); - XmlReaderSettings settingsToUse = closeInput ? - (async ? s_defaultXmlReaderSettingsAsyncCloseInput : s_defaultXmlReaderSettingsCloseInput) : - s_defaultXmlReaderSettings; + XmlReaderSettings settingsToUse = closeInput + ? async + ? s_defaultXmlReaderSettingsAsyncCloseInput + : s_defaultXmlReaderSettingsCloseInput + : s_defaultXmlReaderSettings; return XmlReader.Create(stream, settingsToUse); } - - internal static XmlReader SqlXmlCreateSqlXmlReader(TextReader textReader, bool closeInput, bool async) - { - Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async"); - - XmlReaderSettings settingsToUse = closeInput ? - (async ? s_defaultXmlReaderSettingsAsyncCloseInput : s_defaultXmlReaderSettingsCloseInput) : - s_defaultXmlReaderSettings; - - return XmlReader.Create(textReader, settingsToUse); - } + #endregion #region Work around inability to access SqlDateTime.ToDateTime @@ -90,5 +91,192 @@ internal static DateTime SqlDateTimeToDateTime(int daypart, int timepart) private static Exception ThrowOverflowException() => throw SQL.DateTimeOverflow(); #endregion + + #if NETFRAMEWORK + + #region Work around inability to access `new SqlBinary(byte[], bool)` + + // Documentation of internal constructor: + // https://learn.microsoft.com/en-us/dotnet/framework/additional-apis/system.data.sqltypes.sqlbinary.-ctor + private static readonly Func ByteArrayToSqlBinaryFactory = + CreateFactory(value => new SqlBinary(value)); + + internal static SqlBinary ByteArrayToSqlBinary(byte[] value) => + ByteArrayToSqlBinaryFactory(value); + + #endregion + + #region Work around SqlDecimal.WriteTdsValue not existing in netfx + + /// + /// Implementation that mimics netcore's WriteTdsValue method. + /// + /// + /// Although calls to this method could just be replaced with calls to + /// , using this mimic method allows netfx and netcore + /// implementations to be more cleanly switched. + /// + /// SqlDecimal value to get data from. + /// Span to write data to. + internal static void SqlDecimalWriteTdsValue(SqlDecimal value, Span outSpan) + { + // Note: Although it would be faster to use the m_data[1-4] member variables in + // SqlDecimal, we cannot use them because they are not documented. The Data property + // is less ideal, but is documented. + Debug.Assert(outSpan.Length == 4, "Output span must be 4 elements long."); + + int[] data = value.Data; + outSpan[0] = (uint)data[0]; + outSpan[1] = (uint)data[1]; + outSpan[2] = (uint)data[2]; + outSpan[3] = (uint)data[3]; + } + + #endregion + + #region Work around inability to access `new SqlGuid(byte[], bool)` + + // Documentation for internal constructor: + // https://learn.microsoft.com/en-us/dotnet/framework/additional-apis/system.data.sqltypes.sqlguid.-ctor + private static readonly Func ByteArrayToSqlGuidFactory = + CreateFactory(value => new SqlGuid(value)); + + internal static SqlGuid ByteArrayToSqlGuid(byte[] value) => + ByteArrayToSqlGuidFactory(value); + + #endregion + + #region Work around inability to access `new SqlMoney(long, int)` and `SqlMoney.ToInternalRepresentation()` + + // Documentation for internal ctor: + // https://learn.microsoft.com/en-us/dotnet/framework/additional-apis/system.data.sqltypes.sqlmoney.-ctor + private static readonly Func LongToSqlMoneyFactory = + CreateFactory(value => new SqlMoney((decimal)value / 10000)); + + private delegate long SqlMoneyToLongDelegate(ref SqlMoney @this); + private static readonly SqlMoneyToLongDelegate SqlMoneyToLongFactory = + CreateSqlMoneyToLongFactory(); + + /// + /// Constructs a SqlMoney from a long value without scaling. + /// + /// Internal representation of SqlMoney value. + internal static SqlMoney LongToSqlMoney(long value) => + LongToSqlMoneyFactory(value); + + /// + /// Deconstructs a SqlMoney into a long value with scaling. + /// + /// SqlMoney value + internal static long SqlMoneyToLong(SqlMoney value) => + SqlMoneyToLongFactory(ref value); + + private static SqlMoneyToLongDelegate CreateSqlMoneyToLongFactory() + { + try + { + // Note: Although it would be faster to use the m_value member variable in + // SqlMoney, but because it is not documented, we cannot use it. The method + // we are calling below *is* documented, despite it being internal. + // Documentation for internal method: + // https://learn.microsoft.com/en-us/dotnet/framework/additional-apis/system.data.sqltypes.sqlmoney.tosqlinternalrepresentation + + MethodInfo method = typeof(SqlMoney).GetMethod( + "ToSqlInternalRepresentation", + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.ExactBinding, + binder: null, + types: Array.Empty(), + modifiers: null); + + if (method is not null && method.ReturnType == typeof(long)) + { + // Force warming up the JIT by calling it once. Allegedly doing this *before* + // wrapping in a delegate will give better codegen. + // Note: We must use something other than default since this cannot be used on + // Null SqlMoney structs. + _ = method.Invoke(SqlMoney.Zero, Array.Empty()); + + // Create a delegate for the method. This will be an "open" delegate, meaning + // the instance to call the method on will be provided as arg0 on each call. + // Note the first parameter to the delegate is provided *by reference*. + var del = (SqlMoneyToLongDelegate)method.CreateDelegate(typeof(SqlMoneyToLongDelegate), target: null); + + return del; + } + } + catch + { + // Reflection failed, fall through to using conversion via decimal + } + + // @TODO: SqlMoney.ToSqlInternalRepresentation will throw on SqlMoney.IsNull, the fallback will not. + SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.CreateSqlMoneyToLongFactory | Info | SqlMoney.ToInternalRepresentation(SqlMoney) not found. Less efficient fallback method will be used."); + return (ref SqlMoney value) => value.IsNull ? 0 : (long)(value.ToDecimal() * 10000); + } + + #endregion + + private static unsafe Func CreateFactory( + Func fallbackFactory) + where TInstance : struct + { + // The logic of this method is that there are special internal methods that can create + // Sql* types without the need for copying. These methods are internal to System.Data, + // so we cannot access them, even they are so much faster. To get around this, we + // take a small perf hit to discover them via reflection in exchange for the faster + // perf. If reflection fails, we fall back and use the publicly available ctor, but + // it will be much slower. + // The TIgnored type is an extra argument to the ctor that differentiates this internal + // ctor from the public ctor. + + try + { + // Look for TInstance constructor that takes TValue, TIgnored + ConstructorInfo ctor = typeof(TInstance).GetConstructor( + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, + binder: null, + types: new[] { typeof(TValue), typeof(TIgnored) }, + modifiers: null); + + if (ctor is not null) + { + // Use function pointer for maximum performance on repeated calls. + // This avoids delegate allocation overhead and is nearly as fast as direct + // calls to the constructor + IntPtr fnPtr; + + TInstance FastFactory(TValue value) + { + TInstance result = default; + ((delegate* managed)fnPtr)( + ref result, + value, + default /*ignored*/); + return result; + } + + // Force JIT compilation with a dummy function pointer first + static void DummyNoOp(ref TInstance @this, TValue value, TIgnored ignored) { } + fnPtr = (IntPtr)(delegate* managed)(&DummyNoOp); + FastFactory(default); + + // Replace with real constructor function pointer + fnPtr = ctor.MethodHandle.GetFunctionPointer(); + return FastFactory; + } + } + catch + { + // Reflection failed, fall through to use the slow conversion. + } + + // If reflection failed, or the ctor couldn't be found, fallback to construction using + // the fallback factory. This will be much slower, but ensures conversion can still + // happen. + SqlClientEventSource.Log.TryTraceEvent("SqlTypeWorkarounds.CreateFactory | Info | {0}..ctor({1}, {2}) not found. Less efficient fallback method will be used.", typeof(TInstance).Name, typeof(TValue).Name, typeof(TIgnored).Name); + return fallbackFactory; + } + + #endif } } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlTypes/SqlTypeWorkaroundsTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlTypes/SqlTypeWorkaroundsTests.cs new file mode 100644 index 0000000000..f7cd1811ed --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlTypes/SqlTypeWorkaroundsTests.cs @@ -0,0 +1,206 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data.SqlTypes; +using Microsoft.Data.SqlTypes; +using Xunit; + +#nullable enable + +namespace Microsoft.Data.SqlClient.UnitTests +{ + public class SqlTypeWorkaroundsTests + { + // @TODO: Need a facade pattern for Type getting so we can test the case where reflection fails + + #if NETFRAMEWORK + + #region SqlBinary + + public static TheoryData ByteArrayToSqlBinary_NonNullInput_Data => + new TheoryData + { + Array.Empty(), + new byte[] { 1, 2, 3, 4}, + }; + + [Theory] + [MemberData(nameof(ByteArrayToSqlBinary_NonNullInput_Data))] + public void ByteArrayToSqlBinary_NonNullInput(byte[] input) + { + // Act + SqlBinary result = SqlTypeWorkarounds.ByteArrayToSqlBinary(input); + + // Assert + Assert.False(result.IsNull); + Assert.Equal(input, result.Value); + } + + [Fact] + public void ByteArrayToSqlBinary_NullInput() + { + // Act + SqlBinary result = SqlTypeWorkarounds.ByteArrayToSqlBinary(null); + + // Assert + Assert.True(result.IsNull); + } + + #endregion + + #region SqlDecimal + + public static TheoryData SqlDecimalWriteTdsValue_NonNullInput_Data => + new TheoryData + { + SqlDecimal.MinValue, + new SqlDecimal(-1.2345678), + new SqlDecimal(0), + new SqlDecimal(1.2345678), + SqlDecimal.MaxValue, + }; + + [Theory] + [MemberData(nameof(SqlDecimalWriteTdsValue_NonNullInput_Data))] + public void SqlDecimalWriteTdsValue_NonNullInput(SqlDecimal input) + { + // Arrange + Span output = stackalloc uint[4]; + + // Act + SqlTypeWorkarounds.SqlDecimalWriteTdsValue(input, output); + + // Assert + int[] expected = input.Data; + Assert.Equal(expected[0], (int)output[0]); + Assert.Equal(expected[1], (int)output[1]); + Assert.Equal(expected[2], (int)output[2]); + Assert.Equal(expected[3], (int)output[3]); + } + + [Fact] + public void SqlDecimalWriteTdsValue_NullInput() + { + Action action = () => + { + // Arrange + SqlDecimal input = SqlDecimal.Null; + Span output = stackalloc uint[4]; + + // Act + SqlTypeWorkarounds.SqlDecimalWriteTdsValue(input, output); + }; + + // Assert + Assert.Throws(action); + } + + #endregion + + #region SqlGuid + + public static TheoryData ByteArrayToSqlGuid_InvalidInput_Data => + new TheoryData + { + null, + Array.Empty(), + new byte[] { 1, 2, 3, 4 }, // Too short + new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 } // Too long + }; + + [Theory] + [MemberData(nameof(ByteArrayToSqlGuid_InvalidInput_Data))] + public void ByteArrayToSqlGuid_InvalidInput(byte[]? input) + { + // Act + Action action = () => SqlTypeWorkarounds.ByteArrayToSqlGuid(input); + + // Assert + Assert.Throws(action); + } + + public static TheoryData ByteArrayToSqlGuid_ValidInput_Data => + new TheoryData + { + new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 } + }; + + [Theory] + [MemberData(nameof(ByteArrayToSqlGuid_ValidInput_Data))] + public void ByteArrayToSqlGuid_ValidInput(byte[] input) + { + // Act + SqlGuid result = SqlTypeWorkarounds.ByteArrayToSqlGuid(input); + + // Assert + Assert.False(result.IsNull); + Assert.Equal(input, result.Value.ToByteArray()); + } + + #endregion + + #region SqlMoney + + public static TheoryData LongToSqlMoney_Data => + new TheoryData + { + { long.MinValue, SqlMoney.MinValue }, + { (long)((decimal)-123000000 / 10000), new SqlMoney(-1.23) }, + { 0, SqlMoney.Zero }, + { (long)((decimal)123000000 / 10000), new SqlMoney(1.23) }, + { long.MaxValue, SqlMoney.MaxValue }, + }; + + [Theory] + [MemberData(nameof(LongToSqlMoney_Data))] + public void LongToSqlMoney(long input, SqlMoney expected) + { + // Act + SqlMoney result = SqlTypeWorkarounds.LongToSqlMoney(input); + + // Assert + Assert.Equal(expected, result); + } + + public static TheoryData SqlMoneyToLong_NonNullInput_Data => + new TheoryData + { + { SqlMoney.MinValue, long.MinValue }, + { new SqlMoney(-1.23), (long)(new SqlMoney(-1.23).ToDecimal() * 10000) }, + { SqlMoney.Zero, 0 }, + { new SqlMoney(1.23), (long)(new SqlMoney(1.23).ToDecimal() * 10000) }, + { SqlMoney.MaxValue, long.MaxValue }, + }; + + [Theory] + [MemberData(nameof(SqlMoneyToLong_NonNullInput_Data))] + public void SqlMoneyToLong_NonNullInput(SqlMoney input, long expected) + { + // Act + long result = SqlTypeWorkarounds.SqlMoneyToLong(input); + + // Assert + Assert.Equal(expected, result); + } + + [Fact] + public void SqlMoneyToLong_NullInput() + { + // Arrange + SqlMoney input = SqlMoney.Null; + + // Act + Action action = () => SqlTypeWorkarounds.SqlMoneyToLong(input); + + // Assert + Assert.Throws(action); + } + + #endregion + + #endif + } +}