Skip to content

Commit e1c9508

Browse files
committed
add include extension
1 parent 824ef6b commit e1c9508

File tree

5 files changed

+119
-15
lines changed

5 files changed

+119
-15
lines changed

JsonApiDotNetCore/Data/ResourceRepository.cs

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq.Expressions;
55
using JsonApiDotNetCore.Abstractions;
66
using System.Reflection;
7+
using JsonApiDotNetCore.Extensions;
78
using JsonApiDotNetCore.Routing;
89
using Microsoft.EntityFrameworkCore;
910

@@ -28,32 +29,53 @@ public object Get(string id)
2829
var relationalRoute = _context.Route as RelationalRoute;
2930
if (relationalRoute == null)
3031
{
31-
return GetEntityById(id);
32+
return GetEntityById(_context.Route.BaseModelType, id, null);
3233
}
33-
return GetRelated(id, relationalRoute.RelationshipName);
34+
return GetRelated(id, relationalRoute);
3435
}
3536

36-
private object GetRelated(string id, string relationshipName)
37+
private object GetRelated(string id, RelationalRoute relationalRoute)
3738
{
3839
// HACK: this would rely on lazy loading to work...will probably fail
39-
var entity = GetEntityById(id);
40-
var entityType = entity.GetType();
41-
return entityType.GetProperties().FirstOrDefault(pi => pi.Name == relationshipName).GetValue(entity);
40+
var entity = GetEntityById(relationalRoute.RelationalType, id, relationalRoute.RelationshipName);
41+
return relationalRoute.RelationalType.GetProperties().FirstOrDefault(pi => pi.Name.ToCamelCase() == relationalRoute.RelationshipName.ToCamelCase()).GetValue(entity);
4242
}
4343

44-
private object GetDbSetFromContext(string propName)
44+
45+
private IQueryable GetDbSetFromContext(string propName)
4546
{
4647
var dbContext = _context.DbContext;
47-
return dbContext.GetType().GetProperties().FirstOrDefault(pI => pI.Name == propName)?.GetValue(dbContext, null);
48+
return (IQueryable)dbContext.GetType().GetProperties().FirstOrDefault(pI => pI.Name.ToCamelCase() == propName)?.GetValue(dbContext, null);
4849
}
4950

50-
private object GetEntityById(string id)
51+
private object GetEntityById(Type modelType, string id, string includedRelationship)
5152
{
5253
// HACK: I _believe_ by casting to IEnumerable, we are loading all records into memory, if so... find a better way...
5354
// Also, we are making a BIG assumption that the resource has an attribute Id and not ResourceId which is allowed by EF
54-
return
55-
(GetDbSetFromContext(_context.Route.BaseRouteDefinition.ContextPropertyName) as IEnumerable<dynamic>)?
56-
.FirstOrDefault(x => x.Id.ToString() == id);
55+
var methodToCall = typeof(ResourceRepository).GetMethods().Single(method => method.Name.Equals("GetDbSet"));
56+
var genericMethod = methodToCall.MakeGenericMethod(modelType);
57+
genericMethod.Invoke(genericMethod, null);
58+
var dbSet = genericMethod.Invoke(this, null);
59+
60+
if (!string.IsNullOrEmpty(includedRelationship))
61+
{
62+
var includeMethod = typeof(ResourceRepository).GetMethods().Single(method => method.Name.Equals("IncludeEntity"));
63+
var genericIncludeMethod = includeMethod.MakeGenericMethod(modelType);
64+
genericIncludeMethod.Invoke(genericMethod, null);
65+
dbSet = genericIncludeMethod.Invoke(this, new []{ dbSet, includedRelationship });
66+
}
67+
68+
return (dbSet as IEnumerable<dynamic>).SingleOrDefault(x => x.Id.ToString() == id);
69+
}
70+
71+
private DbSet<T> GetDbSet<T>() where T : class
72+
{
73+
return ((DbContext) _context.DbContext).Set<T>();
74+
}
75+
76+
private IQueryable<T> IncludeEntity<T>(IQueryable<T> queryable, string includedEntityName) where T : class
77+
{
78+
return queryable.Include(includedEntityName);
5779
}
5880
}
5981
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SOURCE: https://github.com/aspnet/EntityFramework/issues/3921
2+
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Linq.Expressions;
7+
using System.Reflection;
8+
using Microsoft.EntityFrameworkCore;
9+
10+
namespace JsonApiDotNetCore.Extensions
11+
{
12+
public static class EfIncludeExtension
13+
{
14+
private static MethodInfo _include = typeof(EntityFrameworkQueryableExtensions)
15+
.GetMethod("Include");
16+
17+
private static MethodInfo _thenIncludeReference = typeof(EntityFrameworkQueryableExtensions)
18+
.GetMethods()
19+
.Where(m => m.Name == "ThenInclude")
20+
.Single(m => m.Name == "ThenInclude" &&
21+
m.GetParameters()
22+
.Single(p => p.Name == "source")
23+
.ParameterType
24+
.GetGenericArguments()[1].Name != typeof(ICollection<>).Name);
25+
26+
private static MethodInfo _thenIncludeCollection = typeof(EntityFrameworkQueryableExtensions)
27+
.GetMethods()
28+
.Where(m => m.Name == "ThenInclude")
29+
.Single(m => m.Name == "ThenInclude" &&
30+
m.GetParameters()
31+
.Single(p => p.Name == "source")
32+
.ParameterType
33+
.GetGenericArguments()[1].Name == typeof(ICollection<>).Name);
34+
35+
public static IQueryable<T> Include<T>(this IQueryable<T> query, string include)
36+
{
37+
return query.Include(include.Split('.'));
38+
}
39+
40+
public static IQueryable<T> Include<T>(this IQueryable<T> query, params string[] include)
41+
{
42+
var currentType = query.ElementType;
43+
var previousNavWasCollection = false;
44+
45+
for (int i = 0; i < include.Length; i++)
46+
{
47+
var navigationName = include[i];
48+
var navigationProperty = currentType.GetProperty(navigationName);
49+
if (navigationProperty == null)
50+
{
51+
throw new ArgumentException($"'{navigationName}' is not a valid property of '{currentType}'");
52+
}
53+
54+
var includeMethod = i == 0
55+
? _include.MakeGenericMethod(query.ElementType, navigationProperty.PropertyType)
56+
: previousNavWasCollection
57+
? _thenIncludeCollection.MakeGenericMethod(query.ElementType, currentType, navigationProperty.PropertyType)
58+
: _thenIncludeReference.MakeGenericMethod(query.ElementType, currentType, navigationProperty.PropertyType);
59+
60+
var expressionParameter = Expression.Parameter(currentType);
61+
var expression = Expression.Lambda(
62+
Expression.Property(expressionParameter, navigationName),
63+
expressionParameter);
64+
65+
query = (IQueryable<T>)includeMethod.Invoke(null, new object[] { query, expression });
66+
67+
if (navigationProperty.PropertyType.GetInterfaces().Any(x => x.Name == typeof(ICollection<>).Name))
68+
{
69+
previousNavWasCollection = true;
70+
currentType = navigationProperty.PropertyType.GetGenericArguments().Single();
71+
}
72+
else
73+
{
74+
previousNavWasCollection = false;
75+
currentType = navigationProperty.PropertyType;
76+
}
77+
}
78+
79+
return query;
80+
}
81+
}
82+
}

JsonApiDotNetCore/Extensions/PathStringExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public static string ExtractFirstSegment(this PathString path, out PathString re
1212
if (!path.HasValue) return string.Empty;
1313

1414
var splitPath = SplitPath(path);
15-
remainingSegments = new PathString(RemoveFirstSegmentFromPath(splitPath));
15+
remainingSegments = new PathString($"/{RemoveFirstSegmentFromPath(splitPath)}");
1616
return splitPath[0];
1717
}
1818

JsonApiDotNetCore/Routing/RouteBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ private PathString SetBaseRouteDefinition()
5454
if (_request.Path.StartsWithSegments(new PathString(rte.PathString), StringComparison.OrdinalIgnoreCase, out remainingPathString))
5555
{
5656
_baseRouteDefinition = rte;
57+
return remainingPathString;
5758
}
58-
return remainingPathString;
5959
}
6060
throw new Exception("Route is not defined.");
6161
}

JsonApiDotNetCore/Routing/Router.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ private void SendResponse(ObjectResult result)
7676
var context = _jsonApiContext.HttpContext;
7777
context.Response.StatusCode = result.StatusCode ?? 500;
7878
context.Response.ContentType = "application/vnd.api+json";
79-
context.Response.WriteAsync(result.Value.ToString());
79+
context.Response.WriteAsync(result.Value == null ? "" : result.Value.ToString());
8080
context.Response.Body.Flush();
8181
}
8282
}

0 commit comments

Comments
 (0)