Skip to content

Commit 97f2667

Browse files
LucStrrstam
authored andcommitted
CSHARP-4935: Support casting from an interface to a type that implements that interface in LINQ queries.
1 parent 9ced8a1 commit 97f2667

File tree

5 files changed

+126
-13
lines changed

5 files changed

+126
-13
lines changed

src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public DiscriminatedInterfaceSerializer()
7575
/// <exception cref="System.ArgumentException">interfaceType</exception>
7676
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
7777
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention)
78-
: this(discriminatorConvention, CreateInterfaceSerializer())
78+
: this(discriminatorConvention, CreateInterfaceSerializer(), objectSerializer: null)
7979
{
8080
}
8181

@@ -87,6 +87,19 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo
8787
/// <exception cref="System.ArgumentException">interfaceType</exception>
8888
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
8989
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer)
90+
: this(discriminatorConvention, interfaceSerializer, objectSerializer: null)
91+
{
92+
}
93+
94+
/// <summary>
95+
/// Initializes a new instance of the <see cref="DiscriminatedInterfaceSerializer{TInterface}" /> class.
96+
/// </summary>
97+
/// <param name="discriminatorConvention">The discriminator convention.</param>
98+
/// <param name="interfaceSerializer">The interface serializer (necessary to support LINQ queries).</param>
99+
/// <param name="objectSerializer">The serializer that is used to serialize any objects.</param>
100+
/// <exception cref="System.ArgumentException">interfaceType</exception>
101+
/// <exception cref="System.ArgumentNullException">interfaceType</exception>
102+
public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer<TInterface> interfaceSerializer, IBsonSerializer<object> objectSerializer)
90103
{
91104
var interfaceTypeInfo = typeof(TInterface).GetTypeInfo();
92105
if (!interfaceTypeInfo.IsInterface)
@@ -97,20 +110,25 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo
97110

98111
_interfaceType = typeof(TInterface);
99112
_discriminatorConvention = discriminatorConvention ?? interfaceSerializer.GetDiscriminatorConvention();
100-
_objectSerializer = BsonSerializer.LookupSerializer<object>();
101-
if (_objectSerializer is ObjectSerializer standardObjectSerializer)
102-
{
103-
_objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention);
104-
}
105-
else
113+
_interfaceSerializer = interfaceSerializer;
114+
115+
if (objectSerializer == null)
106116
{
107-
if (discriminatorConvention != null)
117+
objectSerializer = BsonSerializer.LookupSerializer<object>();
118+
if (objectSerializer is ObjectSerializer standardObjectSerializer)
119+
{
120+
Func<Type, bool> allowedTypes = (Type type) => typeof(TInterface).IsAssignableFrom(type);
121+
objectSerializer = standardObjectSerializer
122+
.WithDiscriminatorConvention(_discriminatorConvention)
123+
.WithAllowedTypes(allowedTypes, allowedTypes);
124+
}
125+
else
108126
{
109127
throw new BsonSerializationException("Can't set discriminator convention on custom object serializer.");
110128
}
111129
}
112130

113-
_interfaceSerializer = interfaceSerializer;
131+
_objectSerializer = objectSerializer;
114132
}
115133

116134
// public properties

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ private static bool IsConvertToBaseType(Type sourceType, Type targetType)
154154

155155
private static bool IsConvertToDerivedType(Type sourceType, Type targetType)
156156
{
157-
return targetType.IsSubclassOf(sourceType);
157+
return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface
158158
}
159159

160160
private static bool IsConvertToNullableType(Type targetType)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ private static bool IsConvertToBaseType(Type fieldType, Type targetType)
8787

8888
private static bool IsConvertToDerivedType(Type fieldType, Type targetType)
8989
{
90-
return targetType.IsSubclassOfOrImplements(fieldType);
90+
return fieldType.IsAssignableFrom(targetType); // targetType either derives from fieldType or implements fieldType interface
9191
}
9292

9393
private static bool IsConvertToNullable(Type fieldType, Type targetType)

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,47 @@ public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work
192192
result.EnumAsNullableInt.Should().Be(2);
193193
}
194194

195+
[Fact]
196+
public void Should_translate_from_base_interface_to_derived_class_on_method_call()
197+
{
198+
var collection = GetInterfaceCollection();
199+
var queryable = collection.AsQueryable()
200+
.Select(p => new DerivedClass
201+
{
202+
Id = p.Id,
203+
A = ((DerivedClass)p).A.ToUpper()
204+
});
195205

206+
var stages = Translate(collection, queryable);
207+
AssertStages(
208+
stages,
209+
"{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }");
210+
211+
var result = queryable.Single();
212+
result.Id.Should().Be(1);
213+
result.A.Should().Be("ABC");
214+
}
215+
216+
[Fact]
217+
public void Should_translate_from_base_interface_to_derived_class_on_projection()
218+
{
219+
var collection = GetInterfaceCollection();
220+
var queryable = collection.AsQueryable()
221+
.Select(p => new DerivedClass()
222+
{
223+
Id = p.Id,
224+
A = ((DerivedClass)p).A
225+
});
226+
227+
var stages = Translate(collection, queryable);
228+
AssertStages(
229+
stages,
230+
"{ '$project' : { _id : '$_id', A : '$A' } }");
231+
232+
var result = queryable.Single();
233+
result.Id.Should().Be(1);
234+
result.A.Should().Be("abc");
235+
}
196236

197237
private IMongoCollection<BaseClass> GetCollection()
198238
{
@@ -209,7 +249,31 @@ private IMongoCollection<BaseClass> GetCollection()
209249
return collection;
210250
}
211251

212-
private class BaseClass
252+
private IMongoCollection<IBaseInterface> GetInterfaceCollection()
253+
{
254+
var collection = GetCollection<IBaseInterface>("test");
255+
CreateCollection(collection, new DerivedClass()
256+
{
257+
Id = 1,
258+
A = "abc",
259+
Enum = Enum.Two,
260+
NullableEnum = Enum.Two,
261+
EnumAsInt = 2,
262+
EnumAsNullableInt = 2
263+
});
264+
return collection;
265+
}
266+
267+
private interface IBaseInterface
268+
{
269+
public int Id { get; set; }
270+
public Enum Enum { get; set; }
271+
public Enum? NullableEnum { get; set; }
272+
public int EnumAsInt { get; set; }
273+
public int? EnumAsNullableInt { get; set; }
274+
}
275+
276+
private class BaseClass : IBaseInterface
213277
{
214278
public int Id { get; set; }
215279
public Enum Enum { get; set; }

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ public void Filter_using_convert_nullable_enum_to_underlying_type_should_work()
8686
result.Id.Should().Be(2);
8787
}
8888

89+
[Fact]
90+
public void Filter_using_field_from_implementing_type_should_work()
91+
{
92+
var collection = GetInterfaceCollection();
93+
94+
var filter = Builders<IData>.Filter.Eq(x => ((Data)x).AdditionalValue, "value");
95+
96+
var result = collection.Find(filter).Single();
97+
result.Id.Should().Be(2);
98+
}
99+
89100
private IMongoCollection<Data> GetCollection()
90101
{
91102
var collection = GetCollection<Data>("test");
@@ -96,13 +107,33 @@ private IMongoCollection<Data> GetCollection()
96107
return collection;
97108
}
98109

99-
private class Data
110+
private IMongoCollection<IData> GetInterfaceCollection()
111+
{
112+
var collection = GetCollection<IData>("test");
113+
CreateCollection(
114+
collection,
115+
new Data { Id = 1, Enum = Enum.One, NullableEnum = Enum.One, EnumAsInt = 1, EnumAsNullableInt = 1 },
116+
new Data { Id = 2, Enum = Enum.Two, NullableEnum = Enum.Two, EnumAsInt = 2, EnumAsNullableInt = 2, AdditionalValue = "value"});
117+
return collection;
118+
}
119+
120+
private interface IData
121+
{
122+
public int Id { get; set; }
123+
public Enum Enum { get; set; }
124+
public Enum? NullableEnum { get; set; }
125+
public int EnumAsInt { get; set; }
126+
public int? EnumAsNullableInt { get; set; }
127+
}
128+
129+
private class Data : IData
100130
{
101131
public int Id { get; set; }
102132
public Enum Enum { get; set; }
103133
public Enum? NullableEnum { get; set; }
104134
public int EnumAsInt { get; set; }
105135
public int? EnumAsNullableInt { get; set; }
136+
public string AdditionalValue { get; set; }
106137
}
107138

108139
private enum Enum

0 commit comments

Comments
 (0)