Skip to content

Commit ee1d2f0

Browse files
ExanitePerksey
andauthored
Fix behavior of SilkMarshal.StringToPtr and related methods on Linux (#2377)
* Add Silk.NET.Core.Tests project * Add test cases for testing string encoding * Fix issue where PtrToStringArray was ignoring the encoding parameter * Add test case for testing LPWStr char width * Use 4-byte LPWStrs on non-Windows platforms * Use Encoding.UTF32 for PtrToString * Simplify use of "when not Windows" clauses * Also use Encoding.UTF32 for StringIntoSpan * Update test cases and docs to show that LPWStr is UTF-32 on non-Windows See #2377 * Don't test non-ascii characters in TestEncodingString/Array This is because LPStr doesn't always support non-ascii characters. --------- Co-authored-by: Dylan Perks <11160611+Perksey@users.noreply.github.com>
1 parent 52f00d3 commit ee1d2f0

File tree

3 files changed

+150
-30
lines changed

3 files changed

+150
-30
lines changed

src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
15
using System.Collections.Generic;
6+
using System.Runtime.InteropServices;
7+
using System.Text;
28
using Silk.NET.Core.Native;
39
using Xunit;
410

@@ -15,6 +21,44 @@ public class TestSilkMarshal
1521
NativeStringEncoding.LPWStr,
1622
};
1723

24+
private readonly Encoding lpwStrEncoding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
25+
? Encoding.Unicode
26+
: Encoding.UTF32;
27+
28+
private readonly int lpwStrCharacterWidth = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 2 : 4;
29+
30+
[Fact]
31+
public unsafe void TestEncodingToLPWStr()
32+
{
33+
var input = "Hello world 🧵";
34+
35+
var expectedByteCount = lpwStrEncoding.GetByteCount(input);
36+
var expected = new byte[expectedByteCount + lpwStrCharacterWidth];
37+
lpwStrEncoding.GetBytes(input, expected);
38+
39+
var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr);
40+
var pointerByteCount = lpwStrCharacterWidth * (int) SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr);
41+
42+
Assert.Equal(expected, new Span<byte>((void*)pointer, pointerByteCount + lpwStrCharacterWidth));
43+
}
44+
45+
[Fact]
46+
public unsafe void TestEncodingFromLPWStr()
47+
{
48+
var expected = "Hello world 🧵";
49+
50+
var inputByteCount = lpwStrEncoding.GetByteCount(expected);
51+
var input = new byte[inputByteCount + lpwStrCharacterWidth];
52+
lpwStrEncoding.GetBytes(expected, input);
53+
54+
fixed (byte* pInput = input)
55+
{
56+
var output = SilkMarshal.PtrToString((nint)pInput, NativeStringEncoding.LPWStr);
57+
58+
Assert.Equal(expected, output);
59+
}
60+
}
61+
1862
[Fact]
1963
public void TestEncodingString()
2064
{

src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ public enum NativeStringEncoding
99
LPStr = UnmanagedType.LPStr,
1010
LPTStr = UnmanagedType.LPTStr,
1111
LPUTF8Str = UnmanagedType.LPUTF8Str,
12+
/// <summary>
13+
/// On Windows, a null-terminated UTF-16 string. On other platforms, a null-terminated UTF-32 string.
14+
/// </summary>
1215
LPWStr = UnmanagedType.LPWStr,
1316
WinString = UnmanagedType.WinString,
1417
Ansi = LPStr,

src/Core/Silk.NET.Core/Native/SilkMarshal.cs

Lines changed: 103 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
144144
NativeStringEncoding.BStr => -1,
145145
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
146146
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
147-
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
147+
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
148+
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 4,
148149
_ => -1
149150
};
150151

@@ -188,29 +189,38 @@ public static unsafe int StringIntoSpan
188189
int convertedBytes;
189190

190191
fixed (char* firstChar = input)
192+
fixed (byte* bytes = span)
191193
{
192-
fixed (byte* bytes = span)
193-
{
194-
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
195-
}
194+
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
195+
bytes[convertedBytes] = 0;
196196
}
197197

198-
span[convertedBytes] = 0;
199-
return ++convertedBytes;
198+
return convertedBytes + 1;
200199
}
201-
case NativeStringEncoding.LPWStr:
200+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
202201
{
203202
fixed (char* firstChar = input)
203+
fixed (byte* bytes = span)
204204
{
205-
fixed (byte* bytes = span)
206-
{
207-
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
208-
((char*)bytes)[input.Length] = default;
209-
}
205+
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
206+
((char*)bytes)[input.Length] = default;
210207
}
211208

212209
return input.Length + 1;
213210
}
211+
case NativeStringEncoding.LPWStr:
212+
{
213+
int convertedBytes;
214+
215+
fixed (char* firstChar = input)
216+
fixed (byte* bytes = span)
217+
{
218+
convertedBytes = Encoding.UTF32.GetBytes(firstChar, input.Length, bytes, span.Length - 4);
219+
((uint*)bytes)[convertedBytes / 4] = 0;
220+
}
221+
222+
return convertedBytes + 4;
223+
}
214224
default:
215225
{
216226
ThrowInvalidEncoding<GlobalMemory>();
@@ -311,7 +321,19 @@ static unsafe string BStrToString(nint ptr)
311321
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));
312322

313323
static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
314-
static unsafe string WideToString(nint ptr) => new string((char*) ptr);
324+
325+
static unsafe string WideToString(nint ptr)
326+
{
327+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
328+
{
329+
return new string((char*) ptr);
330+
}
331+
else
332+
{
333+
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
334+
return Encoding.UTF32.GetString((byte*) ptr, 4 * (int) length);
335+
}
336+
};
315337
}
316338

317339
/// <summary>
@@ -524,15 +546,41 @@ Func<nint, string> customUnmarshaller
524546
/// </remarks>
525547
#if NET6_0_OR_GREATER
526548
[MethodImpl(MethodImplOptions.AggressiveInlining)]
527-
public static unsafe nuint StringLength(
549+
public static unsafe nuint StringLength
550+
(
528551
nint ptr,
529552
NativeStringEncoding encoding = NativeStringEncoding.Ansi
530-
) =>
531-
(nuint)(
532-
encoding == NativeStringEncoding.LPWStr
533-
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
534-
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
535-
);
553+
)
554+
{
555+
switch (encoding)
556+
{
557+
default:
558+
{
559+
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
560+
}
561+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
562+
{
563+
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
564+
}
565+
case NativeStringEncoding.LPWStr:
566+
{
567+
// No int overload for CreateReadOnlySpanFromNullTerminated
568+
if (ptr == 0)
569+
{
570+
return 0;
571+
}
572+
573+
nuint length = 0;
574+
while (((uint*) ptr)![length] != 0)
575+
{
576+
length++;
577+
}
578+
579+
return length;
580+
}
581+
}
582+
}
583+
536584
#else
537585
public static unsafe nuint StringLength(
538586
nint ptr,
@@ -543,15 +591,40 @@ public static unsafe nuint StringLength(
543591
{
544592
return 0;
545593
}
546-
nuint ret;
547-
for (
548-
ret = 0;
549-
encoding == NativeStringEncoding.LPWStr
550-
? ((char*)ptr)![ret] != 0
551-
: ((byte*)ptr)![ret] != 0;
552-
ret++
553-
) { }
554-
return ret;
594+
595+
nuint length = 0;
596+
switch (encoding)
597+
{
598+
default:
599+
{
600+
while (((byte*) ptr)![length] != 0)
601+
{
602+
length++;
603+
}
604+
605+
break;
606+
}
607+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
608+
{
609+
while (((char*) ptr)![length] != 0)
610+
{
611+
length++;
612+
}
613+
614+
break;
615+
}
616+
case NativeStringEncoding.LPWStr:
617+
{
618+
while (((uint*) ptr)![length] != 0)
619+
{
620+
length++;
621+
}
622+
623+
break;
624+
}
625+
}
626+
627+
return length;
555628
}
556629
#endif
557630

0 commit comments

Comments
 (0)