1
+ using System . Collections . Concurrent ;
1
2
using System . Collections . Immutable ;
2
3
using System . Linq . Expressions ;
4
+ using System . Reflection ;
3
5
using HotChocolate . Pagination . Expressions ;
4
6
using static HotChocolate . Pagination . Expressions . ExpressionHelpers ;
5
7
@@ -11,6 +13,7 @@ namespace HotChocolate.Pagination;
11
13
public static class PagingQueryableExtensions
12
14
{
13
15
private static readonly AsyncLocal < InterceptorHolder > _interceptor = new ( ) ;
16
+ private static readonly ConcurrentDictionary < ( Type , Type ) , Expression > _countExpressionCache = new ( ) ;
14
17
15
18
/// <summary>
16
19
/// Executes a query with paging and returns the selected page.
@@ -208,6 +211,49 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
208
211
where TKey : notnull
209
212
=> ToBatchPageAsync < TKey , TValue , TValue > ( source , keySelector , t => t , arguments , cancellationToken ) ;
210
213
214
+ /// <summary>
215
+ /// Executes a batch query with paging and returns the selected pages for each parent.
216
+ /// </summary>
217
+ /// <param name="source">
218
+ /// The queryable to be paged.
219
+ /// </param>
220
+ /// <param name="keySelector">
221
+ /// A function to select the key of the parent.
222
+ /// </param>
223
+ /// <param name="arguments">
224
+ /// The paging arguments.
225
+ /// </param>
226
+ /// <param name="includeTotalCount">
227
+ /// If set to <c>true</c> the total count will be included in the result.
228
+ /// </param>
229
+ /// <param name="cancellationToken">
230
+ /// The cancellation token.
231
+ /// </param>
232
+ /// <typeparam name="TKey">
233
+ /// The type of the parent key.
234
+ /// </typeparam>
235
+ /// <typeparam name="TValue">
236
+ /// The type of the items in the queryable.
237
+ /// </typeparam>
238
+ /// <returns></returns>
239
+ /// <exception cref="ArgumentException">
240
+ /// If the queryable does not have any keys specified.
241
+ /// </exception>
242
+ public static ValueTask < Dictionary < TKey , Page < TValue > > > ToBatchPageAsync < TKey , TValue > (
243
+ this IQueryable < TValue > source ,
244
+ Expression < Func < TValue , TKey > > keySelector ,
245
+ PagingArguments arguments ,
246
+ bool includeTotalCount ,
247
+ CancellationToken cancellationToken = default )
248
+ where TKey : notnull
249
+ => ToBatchPageAsync < TKey , TValue , TValue > (
250
+ source ,
251
+ keySelector ,
252
+ t => t ,
253
+ arguments ,
254
+ includeTotalCount : includeTotalCount ,
255
+ cancellationToken ) ;
256
+
211
257
/// <summary>
212
258
/// Executes a batch query with paging and returns the selected pages for each parent.
213
259
/// </summary>
@@ -239,11 +285,55 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
239
285
/// <exception cref="ArgumentException">
240
286
/// If the queryable does not have any keys specified.
241
287
/// </exception>
288
+ public static ValueTask < Dictionary < TKey , Page < TValue > > > ToBatchPageAsync < TKey , TValue , TElement > (
289
+ this IQueryable < TElement > source ,
290
+ Expression < Func < TElement , TKey > > keySelector ,
291
+ Func < TElement , TValue > valueSelector ,
292
+ PagingArguments arguments ,
293
+ CancellationToken cancellationToken = default )
294
+ where TKey : notnull
295
+ => ToBatchPageAsync ( source , keySelector , valueSelector , arguments , includeTotalCount : false , cancellationToken ) ;
296
+
297
+ /// <summary>
298
+ /// Executes a batch query with paging and returns the selected pages for each parent.
299
+ /// </summary>
300
+ /// <param name="source">
301
+ /// The queryable to be paged.
302
+ /// </param>
303
+ /// <param name="keySelector">
304
+ /// A function to select the key of the parent.
305
+ /// </param>
306
+ /// <param name="valueSelector">
307
+ /// A function to select the value of the items in the queryable.
308
+ /// </param>
309
+ /// <param name="arguments">
310
+ /// The paging arguments.
311
+ /// </param>
312
+ /// <param name="includeTotalCount">
313
+ /// If set to <c>true</c> the total count will be included in the result.
314
+ /// </param>
315
+ /// <param name="cancellationToken">
316
+ /// The cancellation token.
317
+ /// </param>
318
+ /// <typeparam name="TKey">
319
+ /// The type of the parent key.
320
+ /// </typeparam>
321
+ /// <typeparam name="TValue">
322
+ /// The type of the items in the queryable.
323
+ /// </typeparam>
324
+ /// <typeparam name="TElement">
325
+ /// The type of the items in the queryable.
326
+ /// </typeparam>
327
+ /// <returns></returns>
328
+ /// <exception cref="ArgumentException">
329
+ /// If the queryable does not have any keys specified.
330
+ /// </exception>
242
331
public static async ValueTask < Dictionary < TKey , Page < TValue > > > ToBatchPageAsync < TKey , TValue , TElement > (
243
332
this IQueryable < TElement > source ,
244
333
Expression < Func < TElement , TKey > > keySelector ,
245
334
Func < TElement , TValue > valueSelector ,
246
335
PagingArguments arguments ,
336
+ bool includeTotalCount ,
247
337
CancellationToken cancellationToken = default )
248
338
where TKey : notnull
249
339
{
@@ -263,6 +353,12 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
263
353
nameof ( arguments ) ) ;
264
354
}
265
355
356
+ Dictionary < TKey , int > ? counts = null ;
357
+ if ( includeTotalCount )
358
+ {
359
+ counts = await GetBatchCountsAsync ( source , keySelector , cancellationToken ) ;
360
+ }
361
+
266
362
source = QueryHelpers . EnsureOrderPropsAreSelected ( source ) ;
267
363
268
364
// we need to move the ordering into the select expression we are constructing
@@ -308,13 +404,67 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
308
404
builder . Add ( valueSelector ( item . Items [ i ] ) ) ;
309
405
}
310
406
311
- var page = CreatePage ( builder . ToImmutable ( ) , arguments , keys , item . Items . Count ) ;
407
+ var totalCount = counts ? . GetValueOrDefault ( item . Key ) ;
408
+ var page = CreatePage ( builder . ToImmutable ( ) , arguments , keys , item . Items . Count , totalCount ) ;
312
409
map . Add ( item . Key , page ) ;
313
410
}
314
411
315
412
return map ;
316
413
}
317
414
415
+ private static async Task < Dictionary < TKey , int > > GetBatchCountsAsync < TElement , TKey > (
416
+ IQueryable < TElement > source ,
417
+ Expression < Func < TElement , TKey > > keySelector ,
418
+ CancellationToken cancellationToken )
419
+ where TKey : notnull
420
+ {
421
+ var query = source
422
+ . GroupBy ( keySelector )
423
+ . Select ( GetOrCreateCountSelector < TElement , TKey > ( ) ) ;
424
+
425
+ TryGetQueryInterceptor ( ) ? . OnBeforeExecute ( query ) ;
426
+
427
+ return await query . ToDictionaryAsync ( t => t . Key , t => t . Count , cancellationToken ) ;
428
+ }
429
+
430
+ private static Expression < Func < IGrouping < TKey , TElement > , CountResult < TKey > > > GetOrCreateCountSelector < TElement , TKey > ( )
431
+ {
432
+ return ( Expression < Func < IGrouping < TKey , TElement > , CountResult < TKey > > > )
433
+ _countExpressionCache . GetOrAdd (
434
+ ( typeof ( TKey ) , typeof ( TElement ) ) ,
435
+ static _ =>
436
+ {
437
+ var groupingType = typeof ( IGrouping < , > ) . MakeGenericType ( typeof ( TKey ) , typeof ( TElement ) ) ;
438
+ var param = Expression . Parameter ( groupingType , "g" ) ;
439
+ var keyProperty = Expression . Property ( param , nameof ( IGrouping < TKey , TElement > . Key ) ) ;
440
+ var countMethod = typeof ( Enumerable )
441
+ . GetMethods ( BindingFlags . Static | BindingFlags . Public )
442
+ . First ( m => m . Name == nameof ( Enumerable . Count ) && m . GetParameters ( ) . Length == 1 )
443
+ . MakeGenericMethod ( typeof ( TElement ) ) ;
444
+ var countCall = Expression . Call ( countMethod , param ) ;
445
+
446
+ var resultCtor = typeof ( CountResult < TKey > ) . GetConstructor ( Type . EmptyTypes ) ! ;
447
+ var newExpr = Expression . New ( resultCtor ) ;
448
+
449
+ var bindings = new List < MemberBinding >
450
+ {
451
+ Expression . Bind ( typeof ( CountResult < TKey > ) . GetProperty ( nameof ( CountResult < TKey > . Key ) ) ! ,
452
+ keyProperty ) ,
453
+ Expression . Bind ( typeof ( CountResult < TKey > ) . GetProperty ( nameof ( CountResult < TKey > . Count ) ) ! ,
454
+ countCall )
455
+ } ;
456
+
457
+ var body = Expression . MemberInit ( newExpr , bindings ) ;
458
+ return Expression . Lambda < Func < IGrouping < TKey , TElement > , CountResult < TKey > > > ( body , param ) ;
459
+ } ) ;
460
+ }
461
+
462
+ private class CountResult < TKey >
463
+ {
464
+ public required TKey Key { get ; set ; }
465
+ public required int Count { get ; set ; }
466
+ }
467
+
318
468
private static Page < T > CreatePage < T > (
319
469
ImmutableArray < T > items ,
320
470
PagingArguments arguments ,
0 commit comments