Skip to content

Commit cb5fabc

Browse files
committed
Added advanced function generation.
1 parent ff565c6 commit cb5fabc

File tree

8 files changed

+226
-28
lines changed

8 files changed

+226
-28
lines changed

HexaGen.Core/CSharp/CsFunctionOverload.cs

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{
33
using HexaGen.Core.Collections;
44
using System.Collections.Generic;
5+
using System.Diagnostics.CodeAnalysis;
56
using System.Text.Json.Serialization;
67

78
public enum CsFunctionKind
@@ -30,6 +31,10 @@ public CsFunctionOverload(string exportedName, string name, string? comment, Dic
3031
Variations = new(variations);
3132
Modifiers = modifiers;
3233
Attributes = attributes;
34+
for (int i = 0; i < variations.Count; i++)
35+
{
36+
ValueVariations.Add(variations[i]);
37+
}
3338
}
3439

3540
public CsFunctionOverload(string exportedName, string name, string? comment, string structName, CsFunctionKind kind, CsType returnType)
@@ -65,6 +70,9 @@ public CsFunctionOverload(string exportedName, string name, string? comment, str
6570

6671
public ConcurrentList<CsFunctionVariation> Variations { get; set; }
6772

73+
[JsonIgnore]
74+
public HashSet<ValueVariation> ValueVariations { get; set; } = [];
75+
6876
public List<string> Modifiers { get; set; }
6977

7078
public List<string> Attributes { get; set; }
@@ -73,44 +81,44 @@ public bool HasVariation(CsFunctionVariation variation)
7381
{
7482
lock (Variations.SyncObject)
7583
{
76-
for (int i = 0; i < Variations.Count; i++)
77-
{
78-
var iation = Variations[i];
79-
if (variation.Parameters.Count != iation.Parameters.Count)
80-
continue;
81-
if (variation.Name != iation.Name)
82-
continue;
83-
84-
bool skip = false;
85-
for (int j = 0; j < iation.Parameters.Count; j++)
86-
{
87-
if (variation.Parameters[j].Type.Name != iation.Parameters[j].Type.Name || variation.Parameters[j].DefaultValue != iation.Parameters[j].DefaultValue)
88-
{
89-
skip = true;
90-
break;
91-
}
92-
}
93-
94-
if (skip)
95-
continue;
84+
return ValueVariations.Contains(variation);
85+
}
86+
}
87+
88+
public bool HasVariation(ValueVariation variation)
89+
{
90+
lock (Variations.SyncObject)
91+
{
92+
return ValueVariations.Contains(variation);
93+
}
94+
}
9695

96+
public bool TryAddVariation(CsFunctionVariation variation)
97+
{
98+
lock (Variations.SyncObject)
99+
{
100+
if (ValueVariations.Add(variation))
101+
{
102+
Variations.Add(variation);
97103
return true;
98104
}
99-
100-
return false;
101105
}
106+
return false;
102107
}
103108

104-
public bool TryAddVariation(CsFunctionVariation variation)
109+
public bool TryAddVariation(ValueVariation valueVariation, [NotNullWhen(true)] out CsFunctionVariation? variation)
105110
{
106111
lock (Variations.SyncObject)
107112
{
108-
if (!HasVariation(variation))
113+
if (ValueVariations.Add(valueVariation))
109114
{
115+
variation = CreateVariationWith();
116+
variation.Parameters.AddRange(valueVariation.Parameters);
110117
Variations.Add(variation);
111118
return true;
112119
}
113120
}
121+
variation = null;
114122
return false;
115123
}
116124

@@ -122,6 +130,9 @@ public bool TryUpdateVariation(CsFunctionVariation oldVariation, CsFunctionVaria
122130
{
123131
Variations.Add(newVariation);
124132
Variations.Remove(oldVariation);
133+
134+
ValueVariations.Remove(oldVariation);
135+
ValueVariations.Add(newVariation);
125136
return true;
126137
}
127138
}

HexaGen.Core/CSharp/CsFunctionVariation.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,5 +293,10 @@ public CsFunctionVariation Clone()
293293
{
294294
return new CsFunctionVariation(Identifier, ExportedName, Name, StructName, Kind, ReturnType.Clone(), Parameters.CloneValues(), GenericParameters.CloneValues(), Modifiers.Clone(), Attributes.Clone());
295295
}
296+
297+
public static implicit operator ValueVariation(CsFunctionVariation variation)
298+
{
299+
return new ValueVariation(variation.Name, variation.Parameters);
300+
}
296301
}
297302
}

HexaGen.Core/CSharp/ValueVariation.cs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
namespace HexaGen.Core.CSharp
2+
{
3+
using System.Collections.Generic;
4+
5+
public readonly struct ValueVariation : IEquatable<ValueVariation>
6+
{
7+
private readonly string name;
8+
private readonly IList<CsParameterInfo> parameters;
9+
10+
public ValueVariation(string name, IList<CsParameterInfo> parameters)
11+
{
12+
this.name = name;
13+
this.parameters = parameters;
14+
}
15+
16+
public readonly string Name => name;
17+
18+
public readonly IList<CsParameterInfo> Parameters => parameters;
19+
20+
public override readonly bool Equals(object? obj)
21+
{
22+
return obj is ValueVariation variation && Equals(variation);
23+
}
24+
25+
public readonly bool Equals(ValueVariation other)
26+
{
27+
if (other.parameters.Count != parameters.Count) return false;
28+
if (other.name != name) return false;
29+
for (int i = 0; i < parameters.Count; i++)
30+
{
31+
if (other.parameters[i].Type.Name != parameters[i].Type.Name || other.parameters[i].DefaultValue != parameters[i].DefaultValue)
32+
{
33+
return false;
34+
}
35+
}
36+
37+
return true;
38+
}
39+
40+
public override readonly int GetHashCode()
41+
{
42+
HashCode code = new();
43+
code.Add(name);
44+
foreach (var parameter in parameters)
45+
{
46+
code.Add(parameter.Type.Name);
47+
}
48+
return code.ToHashCode();
49+
}
50+
51+
public static bool operator ==(ValueVariation left, ValueVariation right)
52+
{
53+
return left.Equals(right);
54+
}
55+
56+
public static bool operator !=(ValueVariation left, ValueVariation right)
57+
{
58+
return !(left == right);
59+
}
60+
}
61+
}

HexaGen.Core/HexaGen.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
88

99
<AssemblyVersion>1.1.1</AssemblyVersion>
10-
<PackageVersion>1.1.3</PackageVersion>
10+
<PackageVersion>1.1.4</PackageVersion>
1111
<Description></Description>
1212
<PackageTags></PackageTags>
1313
<Authors>Juna Meinhold</Authors>

HexaGen/CsCodeGeneratorConfig.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ public partial class CsCodeGeneratorConfig : IGeneratorConfig
257257
{
258258
"HWND",
259259
"nint"
260-
}
260+
},
261+
VaryingTypes = ["ReadOnlySpan<byte>", "string", "ref string"]
261262
};
262263

263264
public static CsCodeGeneratorConfig Load(string file)
@@ -287,6 +288,11 @@ public static CsCodeGeneratorConfig Load(string file)
287288
result.IgnoredTypedefs.Add(item);
288289
}
289290

291+
foreach (var item in Default.VaryingTypes)
292+
{
293+
result.VaryingTypes.Add(item);
294+
}
295+
290296
if (!result.EnableExperimentalOptions)
291297
{
292298
}
@@ -673,6 +679,16 @@ public static CsCodeGeneratorConfig Load(string file)
673679

674680
public readonly List<CsEnumMetadata> CustomEnums = [];
675681

682+
/// <summary>
683+
/// A list of allowed types for generating additional overloads.
684+
/// </summary>
685+
public HashSet<string> VaryingTypes { get; set; } = new();
686+
687+
/// <summary>
688+
/// Generates additional overloads, <c>WARNING</c> this option can really generate many overloads. To filter which type is allowed use <see cref="VaryingTypes"/>
689+
/// </summary>
690+
public bool GenerateAdditionalOverloads { get; set; }
691+
676692
public void Save(string path)
677693
{
678694
File.WriteAllText(path, JsonSerializer.Serialize(this, new JsonSerializerOptions()

HexaGen/Extensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using System.Runtime.InteropServices;
88
using System.Text;
99

10-
internal static class Extensions
10+
public static class Extensions
1111
{
1212
public static string FormatLocationAttribute(this CppElement element)
1313
{

HexaGen/FunctionGeneration/FunctionGenerator.cs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{
33
using CppAst;
44
using HexaGen;
5+
using HexaGen.Core;
56
using HexaGen.Core.CSharp;
67
using System;
78
using System.Collections.Generic;
@@ -252,6 +253,14 @@ public virtual void GenerateVariations(IList<CppParameter> parameters, CsFunctio
252253
}
253254
}
254255

256+
foreach (var valueVariation in GenerateAdditionalOverloads(function.Name, parameterLists))
257+
{
258+
if (function.TryAddVariation(valueVariation, out var variation))
259+
{
260+
ApplySteps(function, variation);
261+
}
262+
}
263+
255264
if (customParameterList != null)
256265
{
257266
for (int i = 0; i < customParameterList.Length; i++)
@@ -267,6 +276,102 @@ public virtual void GenerateVariations(IList<CppParameter> parameters, CsFunctio
267276
}
268277
}
269278

279+
private static (int, int) FindDifferenceRange(CsParameterInfo[][] overloads)
280+
{
281+
int start = -1;
282+
int end = -1;
283+
for (int i = 0; i < overloads[0].Length; i++)
284+
{
285+
bool found = false;
286+
for (int j = 1; j < overloads.Length; j++)
287+
{
288+
var overload = overloads[j];
289+
if (overloads[0][i].Type.Name != overload[i].Type.Name || overloads[0][i].DefaultValue != overload[i].DefaultValue)
290+
{
291+
found = true;
292+
break;
293+
}
294+
}
295+
296+
if (found)
297+
{
298+
if (start == -1)
299+
{
300+
start = i;
301+
}
302+
end = i + 1;
303+
}
304+
}
305+
306+
return (start, end);
307+
}
308+
309+
private IEnumerable<ValueVariation> GenerateAdditionalOverloads(string name, CsParameterInfo[][] overloads)
310+
{
311+
var (start, end) = FindDifferenceRange(overloads);
312+
313+
if (start == -1)
314+
{
315+
return Array.Empty<ValueVariation>();
316+
}
317+
318+
HashSet<ValueVariation> variations = [];
319+
320+
foreach (var originalOverload in overloads)
321+
{
322+
variations.Add(new ValueVariation(name, originalOverload));
323+
}
324+
325+
var baseList = new List<CsParameterInfo>();
326+
baseList.AddRange(overloads[0].AsSpan(0, start));
327+
328+
return GenerateCombinations(name, overloads, variations, baseList, start, end);
329+
}
330+
331+
private IEnumerable<ValueVariation> GenerateCombinations(string name, CsParameterInfo[][] overloads, HashSet<ValueVariation> variations, List<CsParameterInfo> current, int depth, int end)
332+
{
333+
if (depth == end)
334+
{
335+
int delta = overloads[0].Length - end;
336+
if (delta != 0)
337+
{
338+
current.AddRange(overloads[0].AsSpan(end, delta));
339+
}
340+
341+
var variation = new ValueVariation(name, current);
342+
if (!variations.Contains(variation))
343+
{
344+
var clone = current.Clone();
345+
var newVariation = new ValueVariation(name, clone);
346+
yield return newVariation;
347+
variations.Add(newVariation);
348+
}
349+
350+
if (delta != 0)
351+
{
352+
current.RemoveRange(end, delta);
353+
}
354+
355+
yield break;
356+
}
357+
358+
for (int i = 0; i < overloads.Length; i++)
359+
{
360+
var param = overloads[i][depth];
361+
if (settings.VaryingTypes.Count > 0 && !settings.VaryingTypes.Contains(param.Type.Name))
362+
{
363+
param = overloads[0][depth];
364+
}
365+
366+
current.Add(param);
367+
foreach (var variation in GenerateCombinations(name, overloads, variations, current, depth + 1, end))
368+
{
369+
yield return variation;
370+
}
371+
current.RemoveAt(current.Count - 1);
372+
}
373+
}
374+
270375
public virtual void GenerateAttributes(CppParameter cppParameter, Direction direction, CsParameterInfo parameter, List<string> attributes)
271376
{
272377
string paramAttr = $"[NativeName(NativeNameType.Param, \"{cppParameter.Name}\")]";

HexaGen/HexaGen.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
88

99
<AssemblyVersion>1.1.1</AssemblyVersion>
10-
<PackageVersion>1.1.9</PackageVersion>
10+
<PackageVersion>1.1.10</PackageVersion>
1111
<Description></Description>
1212
<PackageTags></PackageTags>
1313
<Authors>Juna Meinhold</Authors>

0 commit comments

Comments
 (0)