Skip to content

Commit 61767ac

Browse files
Required keyed services to be registered instead of returning null (#7602)
Co-authored-by: Michael Staib <michael@chillicream.com>
1 parent 7b14f54 commit 61767ac

File tree

12 files changed

+123
-25
lines changed

12 files changed

+123
-25
lines changed

src/HotChocolate/ApolloFederation/test/ApolloFederation.Tests/EntitiesResolverTests.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using HotChocolate.Execution;
55
using HotChocolate.Language;
66
using Microsoft.Extensions.DependencyInjection;
7+
using Moq;
78
using static HotChocolate.ApolloFederation.TestHelper;
89

910
namespace HotChocolate.ApolloFederation;
@@ -108,12 +109,13 @@ public async Task TestResolveViaEntityResolver_WithDataLoader()
108109
var batchScheduler = new ManualBatchScheduler();
109110
var dataLoader = new FederatedTypeDataLoader(batchScheduler, new DataLoaderOptions());
110111

111-
var context = CreateResolverContext(schema,
112+
var serviceProviderMock = new Mock<IServiceProvider>();
113+
serviceProviderMock.Setup(c => c.GetService(typeof(FederatedTypeDataLoader))).Returns(dataLoader);
114+
115+
var context = CreateResolverContext(
116+
schema,
112117
null,
113-
mock =>
114-
{
115-
mock.Setup(c => c.Service<FederatedTypeDataLoader>()).Returns(dataLoader);
116-
});
118+
mock => mock.Setup(c => c.Services).Returns(serviceProviderMock.Object));
117119

118120
var representations = RepresentationsOf(
119121
nameof(FederatedType),

src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Global.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ public T Resolver<T>()
170170

171171
public T Service<T>() where T : notnull => Services.GetRequiredService<T>();
172172

173-
public T? Service<T>(object key) where T : notnull => Services.GetKeyedService<T>(key);
173+
public T Service<T>(object key) where T : notnull => Services.GetRequiredKeyedService<T>(key);
174174

175175
public object Service(Type service)
176176
{

src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Pure.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ public object Service(Type service)
212212

213213
public T Service<T>() where T : notnull => parentContext.Service<T>();
214214

215-
public T? Service<T>(object key) where T : notnull => parentContext.Service<T>(key);
215+
public T Service<T>(object key) where T : notnull => parentContext.Service<T>(key);
216216

217217
public T Resolver<T>() => parentContext.Resolver<T>();
218218

src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceExpressionHelper.cs

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@ namespace HotChocolate.Resolvers.Expressions.Parameters;
1010
/// </summary>
1111
internal static class ServiceExpressionHelper
1212
{
13-
private const string _service = nameof(IResolverContext.Service);
13+
private const string _serviceResolver = nameof(GetService);
14+
private const string _keyedServiceResolver = nameof(GetKeyedService);
15+
private static readonly Expression _true = Expression.Constant(true);
16+
private static readonly Expression _false = Expression.Constant(false);
1417

1518
private static readonly MethodInfo _getServiceMethod =
16-
ParameterExpressionBuilderHelpers.ContextType.GetMethods().First(
17-
method => method.Name.Equals(_service, StringComparison.Ordinal) &&
18-
method.IsGenericMethod &&
19-
method.GetParameters().Length == 0);
19+
typeof(ServiceExpressionHelper).GetMethods().First(
20+
method => method.Name.Equals(_serviceResolver, StringComparison.Ordinal));
2021

2122
private static readonly MethodInfo _getKeyedServiceMethod =
22-
ParameterExpressionBuilderHelpers.ContextType.GetMethods().First(
23-
method => method.Name.Equals(_service, StringComparison.Ordinal) &&
24-
method.IsGenericMethod &&
25-
method.GetParameters().Length == 1);
23+
typeof(ServiceExpressionHelper).GetMethods().First(
24+
method => method.Name.Equals(_keyedServiceResolver, StringComparison.Ordinal));
2625

2726
/// <summary>
2827
/// Builds the service expression.
@@ -45,14 +44,41 @@ private static Expression BuildDefaultService(ParameterInfo parameter, Expressio
4544
{
4645
var parameterType = parameter.ParameterType;
4746
var argumentMethod = _getServiceMethod.MakeGenericMethod(parameterType);
48-
return Expression.Call(context, argumentMethod);
47+
var nullabilityContext = new NullabilityInfoContext();
48+
var nullabilityInfo = nullabilityContext.Create(parameter);
49+
var isRequired = nullabilityInfo.ReadState == NullabilityState.NotNull;
50+
return Expression.Call(argumentMethod, context, isRequired ? _true : _false);
4951
}
5052

5153
private static Expression BuildDefaultService(ParameterInfo parameter, Expression context, string key)
5254
{
5355
var parameterType = parameter.ParameterType;
54-
var argumentMethod = _getKeyedServiceMethod.MakeGenericMethod(parameterType);
56+
var argumentMethod = _getKeyedServiceMethod.MakeGenericMethod(parameterType);
5557
var keyExpression = Expression.Constant(key, typeof(object));
56-
return Expression.Call(context, argumentMethod, keyExpression);
58+
var nullabilityContext = new NullabilityInfoContext();
59+
var nullabilityInfo = nullabilityContext.Create(parameter);
60+
var isRequired = nullabilityInfo.ReadState == NullabilityState.NotNull;
61+
return Expression.Call(argumentMethod, context, keyExpression, isRequired ? _true : _false);
62+
}
63+
64+
public static TService? GetService<TService>(
65+
IResolverContext context,
66+
bool required)
67+
where TService : notnull
68+
{
69+
return required
70+
? context.Services.GetRequiredService<TService>()
71+
: context.Services.GetService<TService>();
72+
}
73+
74+
public static TService? GetKeyedService<TService>(
75+
IResolverContext context,
76+
object? key,
77+
bool required)
78+
where TService : notnull
79+
{
80+
return required
81+
? context.Services.GetRequiredKeyedService<TService>(key)
82+
: context.Services.GetKeyedService<TService>(key);
5783
}
5884
}

src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceParameterExpressionBuilder.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,16 @@ public ServiceParameterBinding(ParameterInfo parameter)
4545
{
4646
var attribute = parameter.GetCustomAttribute<ServiceAttribute>();
4747
Key = attribute?.Key;
48+
49+
var context = new NullabilityInfoContext();
50+
var nullabilityInfo = context.Create(parameter);
51+
IsRequired = nullabilityInfo.ReadState == NullabilityState.NotNull;
4852
}
4953

5054
public string? Key { get; }
5155

56+
public bool IsRequired { get; }
57+
5258
public ArgumentKind Kind => ArgumentKind.Service;
5359

5460
public bool IsPure => true;
@@ -57,10 +63,14 @@ public T Execute<T>(IResolverContext context) where T : notnull
5763
{
5864
if (Key is not null)
5965
{
60-
return context.Service<T>(Key)!;
66+
return IsRequired
67+
? context.Services.GetRequiredKeyedService<T>(Key)
68+
: context.Services.GetKeyedService<T>(Key)!;
6169
}
6270

63-
return context.Service<T>();
71+
return IsRequired
72+
? context.Services.GetRequiredService<T>()
73+
: context.Services.GetService<T>()!;
6474
}
6575
}
6676
}

src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public interface IResolverContext : IHasContextData
129129
/// <returns>
130130
/// Returns the specified service.
131131
/// </returns>
132-
T? Service<T>(object key) where T : notnull;
132+
T Service<T>(object key) where T : notnull;
133133

134134
/// <summary>
135135
/// Gets a resolver object containing one or more resolvers.

src/HotChocolate/Core/test/Types.CursorPagination.Tests/QueryableCursorPagingProviderTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ public T Service<T>() where T : notnull
697697
throw new NotImplementedException();
698698
}
699699

700-
public T? Service<T>(object key) where T : notnull
700+
public T Service<T>(object key) where T : notnull
701701
{
702702
throw new NotImplementedException();
703703
}

src/HotChocolate/Core/test/Types.Tests/Resolvers/ResolverCompilerTests.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,9 +695,11 @@ public async Task Compile_Arguments_Service()
695695
var resolver = compiler.CompileResolve(member, type).Resolver!;
696696

697697
// assert
698+
var serviceProvider = new Mock<IServiceProvider>();
699+
serviceProvider.Setup(t => t.GetService(typeof(MyService))).Returns(new MyService());
698700
var context = new Mock<IResolverContext>();
699701
context.Setup(t => t.Parent<Resolvers>()).Returns(new Resolvers());
700-
context.Setup(t => t.Service<MyService>()).Returns(new MyService());
702+
context.Setup(t => t.Services).Returns(serviceProvider.Object);
701703
var result = (bool)(await resolver(context.Object))!;
702704
Assert.True(result);
703705
}

src/HotChocolate/Core/test/Types.Tests/Resolvers/ResolverServiceTests.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#nullable enable
2+
13
using CookieCrumble;
24
using HotChocolate.Execution;
35
using HotChocolate.Types;
@@ -315,6 +317,38 @@ public async Task Resolver_KeyedService()
315317
result.MatchMarkdownSnapshot();
316318
}
317319

320+
[Fact]
321+
public async Task Resolver_Optional_KeyedService_Does_Not_Exist()
322+
{
323+
var executor =
324+
await new ServiceCollection()
325+
.AddGraphQL()
326+
.AddQueryType<QueryOptional>()
327+
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true)
328+
.BuildRequestExecutorAsync();
329+
330+
var result = await executor.ExecuteAsync("{ foo }");
331+
332+
result.MatchMarkdownSnapshot();
333+
}
334+
335+
[Fact]
336+
public async Task Resolver_Optional_KeyedService_Exists()
337+
{
338+
var executor =
339+
await new ServiceCollection()
340+
.AddKeyedSingleton("abc", (_, _) => new KeyedService("abc"))
341+
.AddKeyedSingleton("def", (_, _) => new KeyedService("def"))
342+
.AddGraphQL()
343+
.AddQueryType<QueryOptional>()
344+
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true)
345+
.BuildRequestExecutorAsync();
346+
347+
var result = await executor.ExecuteAsync("{ foo }");
348+
349+
result.MatchMarkdownSnapshot();
350+
}
351+
318352
public sealed class SayHelloService
319353
{
320354
public string Scope = "Resolver";
@@ -372,6 +406,12 @@ public string Foo([AbcService] KeyedService service)
372406
=> service.Key;
373407
}
374408

409+
public class QueryOptional
410+
{
411+
public string Foo([AbcService] KeyedService? service)
412+
=> service?.Key ?? "No Service";
413+
}
414+
375415
public class KeyedService(string key)
376416
{
377417
public string Key => key;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Resolver_Optional_KeyedService_Does_Not_Exist
2+
3+
```json
4+
{
5+
"data": {
6+
"foo": "No Service"
7+
}
8+
}
9+
```

0 commit comments

Comments
 (0)