Skip to content

Commit e4f2e9f

Browse files
authored
Added support for total count to ToBatchPageAsync. (#7944)
1 parent 65fdaed commit e4f2e9f

File tree

3 files changed

+451
-1
lines changed

3 files changed

+451
-1
lines changed

src/HotChocolate/Pagination/src/Pagination.EntityFramework/Extensions/PagingQueryableExtensions.cs

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using System.Collections.Concurrent;
12
using System.Collections.Immutable;
23
using System.Linq.Expressions;
4+
using System.Reflection;
35
using HotChocolate.Pagination.Expressions;
46
using static HotChocolate.Pagination.Expressions.ExpressionHelpers;
57

@@ -11,6 +13,7 @@ namespace HotChocolate.Pagination;
1113
public static class PagingQueryableExtensions
1214
{
1315
private static readonly AsyncLocal<InterceptorHolder> _interceptor = new();
16+
private static readonly ConcurrentDictionary<(Type, Type), Expression> _countExpressionCache = new();
1417

1518
/// <summary>
1619
/// Executes a query with paging and returns the selected page.
@@ -208,6 +211,49 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
208211
where TKey : notnull
209212
=> ToBatchPageAsync<TKey, TValue, TValue>(source, keySelector, t => t, arguments, cancellationToken);
210213

214+
/// <summary>
215+
/// Executes a batch query with paging and returns the selected pages for each parent.
216+
/// </summary>
217+
/// <param name="source">
218+
/// The queryable to be paged.
219+
/// </param>
220+
/// <param name="keySelector">
221+
/// A function to select the key of the parent.
222+
/// </param>
223+
/// <param name="arguments">
224+
/// The paging arguments.
225+
/// </param>
226+
/// <param name="includeTotalCount">
227+
/// If set to <c>true</c> the total count will be included in the result.
228+
/// </param>
229+
/// <param name="cancellationToken">
230+
/// The cancellation token.
231+
/// </param>
232+
/// <typeparam name="TKey">
233+
/// The type of the parent key.
234+
/// </typeparam>
235+
/// <typeparam name="TValue">
236+
/// The type of the items in the queryable.
237+
/// </typeparam>
238+
/// <returns></returns>
239+
/// <exception cref="ArgumentException">
240+
/// If the queryable does not have any keys specified.
241+
/// </exception>
242+
public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue>(
243+
this IQueryable<TValue> source,
244+
Expression<Func<TValue, TKey>> keySelector,
245+
PagingArguments arguments,
246+
bool includeTotalCount,
247+
CancellationToken cancellationToken = default)
248+
where TKey : notnull
249+
=> ToBatchPageAsync<TKey, TValue, TValue>(
250+
source,
251+
keySelector,
252+
t => t,
253+
arguments,
254+
includeTotalCount: includeTotalCount,
255+
cancellationToken);
256+
211257
/// <summary>
212258
/// Executes a batch query with paging and returns the selected pages for each parent.
213259
/// </summary>
@@ -239,11 +285,55 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
239285
/// <exception cref="ArgumentException">
240286
/// If the queryable does not have any keys specified.
241287
/// </exception>
288+
public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue, TElement>(
289+
this IQueryable<TElement> source,
290+
Expression<Func<TElement, TKey>> keySelector,
291+
Func<TElement, TValue> valueSelector,
292+
PagingArguments arguments,
293+
CancellationToken cancellationToken = default)
294+
where TKey : notnull
295+
=> ToBatchPageAsync(source, keySelector, valueSelector, arguments, includeTotalCount: false, cancellationToken);
296+
297+
/// <summary>
298+
/// Executes a batch query with paging and returns the selected pages for each parent.
299+
/// </summary>
300+
/// <param name="source">
301+
/// The queryable to be paged.
302+
/// </param>
303+
/// <param name="keySelector">
304+
/// A function to select the key of the parent.
305+
/// </param>
306+
/// <param name="valueSelector">
307+
/// A function to select the value of the items in the queryable.
308+
/// </param>
309+
/// <param name="arguments">
310+
/// The paging arguments.
311+
/// </param>
312+
/// <param name="includeTotalCount">
313+
/// If set to <c>true</c> the total count will be included in the result.
314+
/// </param>
315+
/// <param name="cancellationToken">
316+
/// The cancellation token.
317+
/// </param>
318+
/// <typeparam name="TKey">
319+
/// The type of the parent key.
320+
/// </typeparam>
321+
/// <typeparam name="TValue">
322+
/// The type of the items in the queryable.
323+
/// </typeparam>
324+
/// <typeparam name="TElement">
325+
/// The type of the items in the queryable.
326+
/// </typeparam>
327+
/// <returns></returns>
328+
/// <exception cref="ArgumentException">
329+
/// If the queryable does not have any keys specified.
330+
/// </exception>
242331
public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue, TElement>(
243332
this IQueryable<TElement> source,
244333
Expression<Func<TElement, TKey>> keySelector,
245334
Func<TElement, TValue> valueSelector,
246335
PagingArguments arguments,
336+
bool includeTotalCount,
247337
CancellationToken cancellationToken = default)
248338
where TKey : notnull
249339
{
@@ -263,6 +353,12 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
263353
nameof(arguments));
264354
}
265355

356+
Dictionary<TKey, int>? counts = null;
357+
if (includeTotalCount)
358+
{
359+
counts = await GetBatchCountsAsync(source, keySelector, cancellationToken);
360+
}
361+
266362
source = QueryHelpers.EnsureOrderPropsAreSelected(source);
267363

268364
// we need to move the ordering into the select expression we are constructing
@@ -308,13 +404,67 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
308404
builder.Add(valueSelector(item.Items[i]));
309405
}
310406

311-
var page = CreatePage(builder.ToImmutable(), arguments, keys, item.Items.Count);
407+
var totalCount = counts?.GetValueOrDefault(item.Key);
408+
var page = CreatePage(builder.ToImmutable(), arguments, keys, item.Items.Count, totalCount);
312409
map.Add(item.Key, page);
313410
}
314411

315412
return map;
316413
}
317414

415+
private static async Task<Dictionary<TKey, int>> GetBatchCountsAsync<TElement, TKey>(
416+
IQueryable<TElement> source,
417+
Expression<Func<TElement, TKey>> keySelector,
418+
CancellationToken cancellationToken)
419+
where TKey : notnull
420+
{
421+
var query = source
422+
.GroupBy(keySelector)
423+
.Select(GetOrCreateCountSelector<TElement, TKey>());
424+
425+
TryGetQueryInterceptor()?.OnBeforeExecute(query);
426+
427+
return await query.ToDictionaryAsync(t => t.Key, t => t.Count, cancellationToken);
428+
}
429+
430+
private static Expression<Func<IGrouping<TKey, TElement>, CountResult<TKey>>> GetOrCreateCountSelector<TElement, TKey>()
431+
{
432+
return (Expression<Func<IGrouping<TKey, TElement>, CountResult<TKey>>>)
433+
_countExpressionCache.GetOrAdd(
434+
(typeof(TKey), typeof(TElement)),
435+
static _ =>
436+
{
437+
var groupingType = typeof(IGrouping<,>).MakeGenericType(typeof(TKey), typeof(TElement));
438+
var param = Expression.Parameter(groupingType, "g");
439+
var keyProperty = Expression.Property(param, nameof(IGrouping<TKey, TElement>.Key));
440+
var countMethod = typeof(Enumerable)
441+
.GetMethods(BindingFlags.Static | BindingFlags.Public)
442+
.First(m => m.Name == nameof(Enumerable.Count) && m.GetParameters().Length == 1)
443+
.MakeGenericMethod(typeof(TElement));
444+
var countCall = Expression.Call(countMethod, param);
445+
446+
var resultCtor = typeof(CountResult<TKey>).GetConstructor(Type.EmptyTypes)!;
447+
var newExpr = Expression.New(resultCtor);
448+
449+
var bindings = new List<MemberBinding>
450+
{
451+
Expression.Bind(typeof(CountResult<TKey>).GetProperty(nameof(CountResult<TKey>.Key))!,
452+
keyProperty),
453+
Expression.Bind(typeof(CountResult<TKey>).GetProperty(nameof(CountResult<TKey>.Count))!,
454+
countCall)
455+
};
456+
457+
var body = Expression.MemberInit(newExpr, bindings);
458+
return Expression.Lambda<Func<IGrouping<TKey, TElement>, CountResult<TKey>>>(body, param);
459+
});
460+
}
461+
462+
private class CountResult<TKey>
463+
{
464+
public required TKey Key { get; set; }
465+
public required int Count { get; set; }
466+
}
467+
318468
private static Page<T> CreatePage<T>(
319469
ImmutableArray<T> items,
320470
PagingArguments arguments,

src/HotChocolate/Pagination/test/Pagination.EntityFramework.Tests/InterfaceIntegrationTests.cs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,61 @@ public async Task Query_Owner_Animals()
6565

6666
var operationResult = result.ExpectOperationResult();
6767

68+
#if NET9_0_OR_GREATER
69+
await Snapshot.Create("NET_9_0")
70+
#else
71+
await Snapshot.Create()
72+
#endif
73+
.AddQueries(queries)
74+
.Add(operationResult.WithExtensions(ImmutableDictionary<string, object?>.Empty))
75+
.MatchMarkdownAsync();
76+
}
77+
78+
[Fact]
79+
public async Task Query_Owner_Animals_With_TotalCount()
80+
{
81+
var connectionString = CreateConnectionString();
82+
await SeedAsync(connectionString);
83+
84+
var queries = new List<QueryInfo>();
85+
using var capture = new CapturePagingQueryInterceptor(queries);
86+
87+
var result = await new ServiceCollection()
88+
.AddScoped(_ => new AnimalContext(connectionString))
89+
.AddGraphQL()
90+
.AddQueryType<Query>()
91+
.AddTypeExtension(typeof(OwnerExtensionsWithTotalCount))
92+
.AddDataLoader<AnimalsByOwnerWithCountDataLoader>()
93+
.AddObjectType<Cat>()
94+
.AddObjectType<Dog>()
95+
.AddPagingArguments()
96+
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true)
97+
.ModifyPagingOptions(o => o.IncludeTotalCount = true)
98+
.ExecuteRequestAsync(
99+
OperationRequestBuilder.New()
100+
.SetDocument(
101+
"""
102+
{
103+
owners(first: 10) {
104+
nodes {
105+
id
106+
name
107+
pets(first: 10) {
108+
nodes {
109+
__typename
110+
id
111+
name
112+
}
113+
totalCount
114+
}
115+
}
116+
}
117+
}
118+
""")
119+
.Build());
120+
121+
var operationResult = result.ExpectOperationResult();
122+
68123
#if NET9_0_OR_GREATER
69124
await Snapshot.Create("NET_9_0")
70125
#else
@@ -314,6 +369,24 @@ public static async Task<Connection<Animal>> GetPetsAsync(
314369
.ToConnectionAsync();
315370
}
316371

372+
[ExtendObjectType<Owner>]
373+
public static class OwnerExtensionsWithTotalCount
374+
{
375+
[BindMember(nameof(Owner.Pets))]
376+
[UsePaging]
377+
public static async Task<Connection<Animal>> GetPetsAsync(
378+
[Parent("Id")] Owner owner,
379+
PagingArguments pagingArgs,
380+
AnimalsByOwnerWithCountDataLoader animalsByOwner,
381+
ISelection selection,
382+
CancellationToken cancellationToken)
383+
=> await animalsByOwner
384+
.WithPagingArguments(pagingArgs)
385+
.Select(selection)
386+
.LoadAsync(owner.Id, cancellationToken)
387+
.ToConnectionAsync();
388+
}
389+
317390
public sealed class AnimalsByOwnerDataLoader
318391
: StatefulBatchDataLoader<int, Page<Animal>>
319392
{
@@ -352,6 +425,46 @@ protected override async Task<IReadOnlyDictionary<int, Page<Animal>>> LoadBatchA
352425
cancellationToken);
353426
}
354427
}
428+
429+
public sealed class AnimalsByOwnerWithCountDataLoader
430+
: StatefulBatchDataLoader<int, Page<Animal>>
431+
{
432+
private readonly IServiceProvider _services;
433+
434+
public AnimalsByOwnerWithCountDataLoader(
435+
IServiceProvider services,
436+
IBatchScheduler batchScheduler,
437+
DataLoaderOptions options)
438+
: base(batchScheduler, options)
439+
{
440+
_services = services;
441+
}
442+
443+
protected override async Task<IReadOnlyDictionary<int, Page<Animal>>> LoadBatchAsync(
444+
IReadOnlyList<int> keys,
445+
DataLoaderFetchContext<Page<Animal>> context,
446+
CancellationToken cancellationToken)
447+
{
448+
var pagingArgs = context.GetPagingArguments();
449+
// var selector = context.GetSelector();
450+
451+
await using var scope = _services.CreateAsyncScope();
452+
var dbContext = scope.ServiceProvider.GetRequiredService<AnimalContext>();
453+
454+
return await dbContext.Owners
455+
.Where(t => keys.Contains(t.Id))
456+
.SelectMany(t => t.Pets)
457+
.OrderBy(t => t.Name)
458+
.ThenBy(t => t.Id)
459+
// selections do not work when inheritance is used for nested batching.
460+
// .Select(selector, t => t.OwnerId)
461+
.ToBatchPageAsync(
462+
t => t.OwnerId,
463+
pagingArgs,
464+
includeTotalCount: true,
465+
cancellationToken);
466+
}
467+
}
355468
}
356469

357470
file static class Extensions

0 commit comments

Comments
 (0)