Skip to content

Commit 55af261

Browse files
authored
Merge pull request CommunityToolkit#654 from CommunityToolkit/user/sergiopedri/tweak-di-generator
Update DI generator for lazy-init, better codegen
2 parents ae22771 + c15a065 commit 55af261

File tree

2 files changed

+64
-51
lines changed

2 files changed

+64
-51
lines changed

components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ namespace CommunityToolkit.Extensions.DependencyInjection.SourceGenerators.Model
1010
/// A model for a singleton service registration.
1111
/// </summary>
1212
/// <param name="RegistrationKind">The registration kind for the service.</param>
13+
/// <param name="ImplementationTypeName">The type name of the implementation type.</param>
1314
/// <param name="ImplementationFullyQualifiedTypeName">The fully qualified type name of the implementation type.</param>
1415
/// <param name="RequiredServiceFullyQualifiedTypeNames">The fully qualified type names of dependent services for <paramref name="ImplementationFullyQualifiedTypeName"/>.</param>
1516
/// <param name="ServiceFullyQualifiedTypeNames">The fully qualified type names for the services to register for <paramref name="ImplementationFullyQualifiedTypeName"/>.</param>
1617
internal sealed record RegisteredServiceInfo(
1718
ServiceRegistrationKind RegistrationKind,
19+
string ImplementationTypeName,
1820
string ImplementationFullyQualifiedTypeName,
1921
EquatableArray<string> RequiredServiceFullyQualifiedTypeNames,
2022
EquatableArray<string> ServiceFullyQualifiedTypeNames);

components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
143143
// Create the model fully describing the current service registration
144144
serviceInfo.Add(new RegisteredServiceInfo(
145145
RegistrationKind: registrationKind,
146+
ImplementationTypeName: implementationType.Name,
146147
ImplementationFullyQualifiedTypeName: implementationTypeName,
147148
ServiceFullyQualifiedTypeNames: serviceTypeNames,
148149
RequiredServiceFullyQualifiedTypeNames: constructorArgumentTypes));
@@ -166,8 +167,11 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
166167
/// <returns>A <see cref="CompilationUnitSyntax"/> instance with the gathered info.</returns>
167168
public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
168169
{
170+
using ImmutableArrayBuilder<LocalFunctionStatementSyntax> localFunctions = ImmutableArrayBuilder<LocalFunctionStatementSyntax>.Rent();
169171
using ImmutableArrayBuilder<StatementSyntax> registrationStatements = ImmutableArrayBuilder<StatementSyntax>.Rent();
170172

173+
int index = -1;
174+
171175
foreach (RegisteredServiceInfo serviceInfo in info.Services)
172176
{
173177
// The first service type always acts as "main" registration, and should always be present
@@ -176,6 +180,9 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
176180
continue;
177181
}
178182

183+
// Increment the index we use to disambiguate the generated local function names (starting from 0)
184+
index++;
185+
179186
using ImmutableArrayBuilder<ArgumentSyntax> constructorArguments = ImmutableArrayBuilder<ArgumentSyntax>.Rent();
180187

181188
// Prepare the dependent services for the implementation type
@@ -199,54 +206,55 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
199206
// Prepare the method name, either AddSingleton or AddTransient
200207
string registrationMethod = $"Add{serviceInfo.RegistrationKind}";
201208

202-
// Special case when the service is a singleton and no dependent services are present, just use eager instantiation instead:
209+
// Prepare the name of the factory local function
210+
string factoryMethod = $"Get{serviceInfo.ImplementationTypeName}_{index}";
211+
212+
// Prepare the local function for the registration (to improve lambda caching):
203213
//
204-
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), new <IMPLEMENTATION_TYPE>());
205-
if (serviceInfo.RegistrationKind == ServiceRegistrationKind.Singleton && constructorArguments.Count == 0)
206-
{
207-
registrationStatements.Add(
208-
ExpressionStatement(
209-
InvocationExpression(
210-
MemberAccessExpression(
211-
SyntaxKind.SimpleMemberAccessExpression,
212-
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
213-
IdentifierName("AddSingleton")))
214-
.AddArgumentListArguments(
215-
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
216-
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
217-
Argument(
218-
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
219-
.WithArgumentList(ArgumentList())))));
220-
}
221-
else
222-
{
223-
// Register the main implementation type when at least a dependent service is needed:
224-
//
225-
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), static services => new <IMPLEMENTATION_TYPE>(<CONSTRUCTOR_ARGUMENTS>));
226-
registrationStatements.Add(
227-
ExpressionStatement(
228-
InvocationExpression(
229-
MemberAccessExpression(
230-
SyntaxKind.SimpleMemberAccessExpression,
231-
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
232-
IdentifierName(registrationMethod)))
233-
.AddArgumentListArguments(
234-
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
235-
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
236-
Argument(
237-
SimpleLambdaExpression(Parameter(Identifier("services")))
238-
.AddModifiers(Token(SyntaxKind.StaticKeyword))
239-
.WithExpressionBody(
240-
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
241-
.AddArgumentListArguments(constructorArguments.ToArray()))))));
242-
}
214+
// static object <FACTORY_METHOD>(global::System.IServiceProvider services)
215+
// {
216+
// return new <IMPLEMENTATION_TYPE>(<CONSTRUCTOR_ARGUMENTS>);
217+
// }
218+
localFunctions.Add(
219+
LocalFunctionStatement(
220+
PredefinedType(Token(SyntaxKind.ObjectKeyword)),
221+
Identifier(factoryMethod))
222+
.AddModifiers(Token(SyntaxKind.StaticKeyword))
223+
.AddParameterListParameters(
224+
Parameter(Identifier("services"))
225+
.WithType(IdentifierName("global::System.IServiceProvider")))
226+
.AddBodyStatements(
227+
ReturnStatement(
228+
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
229+
.AddArgumentListArguments(constructorArguments.ToArray()))));
230+
231+
// Register the main implementation type:
232+
//
233+
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), new global::Func<global::System.IServiceProvider, object>(<FACTORY_METHOD>));
234+
registrationStatements.Add(
235+
ExpressionStatement(
236+
InvocationExpression(
237+
MemberAccessExpression(
238+
SyntaxKind.SimpleMemberAccessExpression,
239+
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
240+
IdentifierName(registrationMethod)))
241+
.AddArgumentListArguments(
242+
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
243+
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
244+
Argument(
245+
ObjectCreationExpression(
246+
GenericName(Identifier("global::System.Func"))
247+
.AddTypeArgumentListArguments(
248+
IdentifierName("global::System.IServiceProvider"),
249+
PredefinedType(Token(SyntaxKind.ObjectKeyword))))
250+
.AddArgumentListArguments(Argument(IdentifierName(factoryMethod)))))));
243251

244252
// Register all secondary services, if any
245253
foreach (string dependentServiceType in dependentServiceTypeNames)
246254
{
247255
// Register the main implementation type:
248256
//
249-
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<DEPENDENT_SERVICE_TYPE>), static services => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices<ROOT_SERVICE_TYPE>(services));
257+
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<DEPENDENT_SERVICE_TYPE>), new global::System.Func<global::System.IServiceProvider, object>(global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices<ROOT_SERVICE_TYPE>));
250258
registrationStatements.Add(
251259
ExpressionStatement(
252260
InvocationExpression(
@@ -258,16 +266,17 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
258266
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
259267
Argument(TypeOfExpression(IdentifierName(dependentServiceType))),
260268
Argument(
261-
SimpleLambdaExpression(Parameter(Identifier("services")))
262-
.AddModifiers(Token(SyntaxKind.StaticKeyword))
263-
.WithExpressionBody(
264-
InvocationExpression(
265-
MemberAccessExpression(
266-
SyntaxKind.SimpleMemberAccessExpression,
267-
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
268-
GenericName(Identifier("GetRequiredService"))
269-
.AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName))))
270-
.AddArgumentListArguments(Argument(IdentifierName("services"))))))));
269+
ObjectCreationExpression(
270+
GenericName("global::System.Func")
271+
.AddTypeArgumentListArguments(
272+
IdentifierName("global::System.IServiceProvider"),
273+
PredefinedType(Token(SyntaxKind.ObjectKeyword))))
274+
.AddArgumentListArguments(Argument(
275+
MemberAccessExpression(
276+
SyntaxKind.SimpleMemberAccessExpression,
277+
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
278+
GenericName(Identifier("GetRequiredService"))
279+
.AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName)))))))));
271280
}
272281
}
273282

@@ -294,6 +303,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
294303
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
295304
// <MODIFIERS> <RETURN_TYPE> <METHOD_NAME>(global::Microsoft.Extensions.DependencyInjection.IServiceCollection <PARAMETER_NAME>)
296305
// {
306+
// <LOCAL_FUNCTIONS>
297307
// <REGISTRATION_STATEMENTS>
298308
// }
299309
MethodDeclarationSyntax configureServicesMethodDeclaration =
@@ -302,6 +312,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
302312
.AddParameterListParameters(
303313
Parameter(Identifier(info.Method.ServiceCollectionParameterName))
304314
.WithType(IdentifierName("global::Microsoft.Extensions.DependencyInjection.IServiceCollection")))
315+
.AddBodyStatements(localFunctions.ToArray())
305316
.AddBodyStatements(registrationStatements.ToArray())
306317
.AddAttributeLists(
307318
AttributeList(SingletonSeparatedList(

0 commit comments

Comments
 (0)