Skip to content

Commit 8238746

Browse files
authored
Merge pull request #20084 from paldepind/rust/type-inference-trait-object
Rust: Implement type inference for trait objects/`dyn` types
2 parents 5da7ae8 + b3dc6cb commit 8238746

File tree

11 files changed

+640
-311
lines changed

11 files changed

+640
-311
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/lib/codeql/rust/elements/internal/DynTraitTypeReprImpl.qll

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `DynTraitTypeRepr`.
43
*
@@ -12,6 +11,10 @@ private import codeql.rust.elements.internal.generated.DynTraitTypeRepr
1211
* be referenced directly.
1312
*/
1413
module Impl {
14+
private import rust
15+
private import codeql.rust.internal.PathResolution as PathResolution
16+
17+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1518
/**
1619
* A dynamic trait object type.
1720
*
@@ -21,5 +24,16 @@ module Impl {
2124
* // ^^^^^^^^^
2225
* ```
2326
*/
24-
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr { }
27+
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr {
28+
/** Gets the trait that this trait object refers to. */
29+
pragma[nomagic]
30+
Trait getTrait() {
31+
result =
32+
PathResolution::resolvePath(this.getTypeBoundList()
33+
.getBound(0)
34+
.getTypeRepr()
35+
.(PathTypeRepr)
36+
.getPath())
37+
}
38+
}
2539
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@ newtype TType =
2424
TArrayType() or // todo: add size?
2525
TRefType() or // todo: add mut?
2626
TImplTraitType(ImplTraitTypeRepr impl) or
27+
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
2728
TSliceType() or
2829
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
2930
TTypeParamTypeParameter(TypeParam t) or
3031
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
3132
TArrayTypeParameter() or
33+
TDynTraitTypeParameter(TypeParam tp) {
34+
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
35+
} or
3236
TRefTypeParameter() or
3337
TSelfTypeParameter(Trait t) or
3438
TSliceTypeParameter()
@@ -247,6 +251,26 @@ class ImplTraitType extends Type, TImplTraitType {
247251
override Location getLocation() { result = impl.getLocation() }
248252
}
249253

254+
class DynTraitType extends Type, TDynTraitType {
255+
Trait trait;
256+
257+
DynTraitType() { this = TDynTraitType(trait) }
258+
259+
override StructField getStructField(string name) { none() }
260+
261+
override TupleField getTupleField(int i) { none() }
262+
263+
override DynTraitTypeParameter getTypeParameter(int i) {
264+
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
265+
}
266+
267+
Trait getTrait() { result = trait }
268+
269+
override string toString() { result = "dyn " + trait.getName().toString() }
270+
271+
override Location getLocation() { result = trait.getLocation() }
272+
}
273+
250274
/**
251275
* An [impl Trait in return position][1] type, for example:
252276
*
@@ -381,6 +405,18 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
381405
override Location getLocation() { result instanceof EmptyLocation }
382406
}
383407

408+
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
409+
private TypeParam typeParam;
410+
411+
DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }
412+
413+
TypeParam getTypeParam() { result = typeParam }
414+
415+
override string toString() { result = "dyn(" + typeParam.toString() + ")" }
416+
417+
override Location getLocation() { result = typeParam.getLocation() }
418+
}
419+
384420
/** An implicit reference type parameter. */
385421
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
386422
override string toString() { result = "&T" }
@@ -465,6 +501,13 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
465501
}
466502
}
467503

504+
final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
505+
override TypeParameter getATypeParameter() {
506+
result.(TypeParamTypeParameter).getTypeParam() =
507+
this.getTrait().getGenericParamList().getATypeParam()
508+
}
509+
}
510+
468511
final class TraitTypeAbstraction extends TypeAbstraction, Trait {
469512
override TypeParameter getATypeParameter() {
470513
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ private module Input1 implements InputSig1<Location> {
9797
id = 2
9898
or
9999
kind = 1 and
100+
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
101+
or
102+
kind = 2 and
100103
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
101104
node = tp0.(TypeParamTypeParameter).getTypeParam() or
102105
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
@@ -107,7 +110,7 @@ private module Input1 implements InputSig1<Location> {
107110
exists(TupleTypeParameter ttp, int maxArity |
108111
maxArity = max(int i | i = any(TupleType tt).getArity()) and
109112
tp0 = ttp and
110-
kind = 2 and
113+
kind = 3 and
111114
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
112115
)
113116
|
@@ -189,6 +192,14 @@ private module Input2 implements InputSig2 {
189192
condition = impl and
190193
constraint = impl.getTypeBoundList().getABound().getTypeRepr()
191194
)
195+
or
196+
// a `dyn Trait` type implements `Trait`. See the comment on
197+
// `DynTypeBoundListMention` for further details.
198+
exists(DynTraitTypeRepr object |
199+
abs = object and
200+
condition = object.getTypeBoundList() and
201+
constraint = object.getTrait()
202+
)
192203
}
193204
}
194205

@@ -1715,10 +1726,16 @@ private Function getMethodFromImpl(MethodCall mc) {
17151726

17161727
bindingset[trait, name]
17171728
pragma[inline_late]
1718-
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
1729+
private Function getImplTraitMethod(ImplTraitReturnType trait, string name) {
17191730
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
17201731
}
17211732

1733+
bindingset[traitObject, name]
1734+
pragma[inline_late]
1735+
private Function getDynTraitMethod(DynTraitType traitObject, string name) {
1736+
result = getMethodSuccessor(traitObject.getTrait(), name)
1737+
}
1738+
17221739
pragma[nomagic]
17231740
private Function resolveMethodCallTarget(MethodCall mc) {
17241741
// The method comes from an `impl` block targeting the type of the receiver.
@@ -1729,7 +1746,10 @@ private Function resolveMethodCallTarget(MethodCall mc) {
17291746
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
17301747
or
17311748
// The type of the receiver is an `impl Trait` type.
1732-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1749+
result = getImplTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1750+
or
1751+
// The type of the receiver is a trait object `dyn Trait` type.
1752+
result = getDynTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
17331753
}
17341754

17351755
pragma[nomagic]
@@ -2073,6 +2093,13 @@ private module Debug {
20732093
result = resolveCallTarget(c)
20742094
}
20752095

2096+
predicate debugConditionSatisfiesConstraint(
2097+
TypeAbstraction abs, TypeMention condition, TypeMention constraint
2098+
) {
2099+
abs = getRelevantLocatable() and
2100+
Input2::conditionSatisfiesConstraint(abs, condition, constraint)
2101+
}
2102+
20762103
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
20772104
self = getRelevantLocatable() and
20782105
t = inferImplicitSelfType(self, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,64 @@ class SelfTypeParameterMention extends TypeMention instanceof Name {
309309
result = TSelfTypeParameter(trait)
310310
}
311311
}
312+
313+
class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
314+
private DynTraitType dynType;
315+
316+
DynTraitTypeReprMention() {
317+
// This excludes `DynTraitTypeRepr` elements where `getTrait` is not
318+
// defined, i.e., where path resolution can't find a trait.
319+
dynType.getTrait() = super.getTrait()
320+
}
321+
322+
override Type resolveTypeAt(TypePath path) {
323+
path.isEmpty() and
324+
result = dynType
325+
or
326+
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
327+
tp = dynType.getTypeParameter(_) and
328+
path = TypePath::cons(tp, suffix) and
329+
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
330+
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
331+
)
332+
}
333+
}
334+
335+
// We want a type of the form `dyn Trait` to implement `Trait`. If `Trait` has
336+
// type parameters then `dyn Trait` has equivalent type parameters and the
337+
// implementation should be abstracted over them.
338+
//
339+
// Intuitively we want something to the effect of:
340+
// ```
341+
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
342+
// ```
343+
// To achieve this:
344+
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
345+
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
346+
// Trait` at the root and which for every type parameter of `dyn Trait` has the
347+
// corresponding type parameter of the trait.
348+
// - `TraitMention` (which is used for other things as well) is a type mention
349+
// for the trait applied to its own type parameters.
350+
//
351+
// We arbitrarily use the `TypeBoundList` inside `DynTraitTypeRepr` to encode
352+
// this type mention, since it doesn't syntactically appear in the AST. This
353+
// works because there is a one-to-one correspondence between a trait object and
354+
// its list of type bounds.
355+
class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
356+
private Trait trait;
357+
358+
DynTypeBoundListMention() {
359+
exists(DynTraitTypeRepr dyn | this = dyn.getTypeBoundList() and trait = dyn.getTrait())
360+
}
361+
362+
override Type resolveTypeAt(TypePath path) {
363+
path.isEmpty() and
364+
result.(DynTraitType).getTrait() = trait
365+
or
366+
exists(TypeParam param |
367+
param = trait.getGenericParamList().getATypeParam() and
368+
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
369+
result = TTypeParamTypeParameter(param)
370+
)
371+
}
372+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Type inference now supports trait objects, i.e., `dyn Trait` types.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Test cases for type inference and method resolution with `dyn` types
2+
3+
use std::fmt::Debug;
4+
5+
trait MyTrait1 {
6+
// MyTrait1::m
7+
fn m(&self) -> String;
8+
}
9+
10+
trait GenericGet<A> {
11+
// GenericGet::get
12+
fn get(&self) -> A;
13+
}
14+
15+
#[derive(Clone, Debug)]
16+
struct MyStruct {
17+
value: i32,
18+
}
19+
20+
impl MyTrait1 for MyStruct {
21+
// MyStruct1::m
22+
fn m(&self) -> String {
23+
format!("MyTrait1: {}", self.value) // $ fieldof=MyStruct
24+
}
25+
}
26+
27+
#[derive(Clone, Debug)]
28+
struct GenStruct<A: Clone + Debug> {
29+
value: A,
30+
}
31+
32+
impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
33+
// GenStruct<A>::get
34+
fn get(&self) -> A {
35+
self.value.clone() // $ fieldof=GenStruct target=clone
36+
}
37+
}
38+
39+
fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
40+
a.get() // $ target=GenericGet::get
41+
}
42+
43+
fn get_box_trait<A: Clone + Debug + 'static>(a: A) -> Box<dyn GenericGet<A>> {
44+
Box::new(GenStruct { value: a }) // $ target=new
45+
}
46+
47+
fn test_basic_dyn_trait(obj: &dyn MyTrait1) {
48+
let _result = (*obj).m(); // $ target=deref target=MyTrait1::m type=_result:String
49+
}
50+
51+
fn test_generic_dyn_trait(obj: &dyn GenericGet<String>) {
52+
let _result1 = (*obj).get(); // $ target=deref target=GenericGet::get type=_result1:String
53+
let _result2 = get_a(obj); // $ target=get_a type=_result2:String
54+
}
55+
56+
fn test_poly_dyn_trait() {
57+
let obj = get_box_trait(true); // $ target=get_box_trait
58+
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
59+
}
60+
61+
pub fn test() {
62+
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
63+
test_generic_dyn_trait(&GenStruct {
64+
value: "".to_string(),
65+
}); // $ target=test_generic_dyn_trait
66+
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
67+
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,8 +2292,6 @@ mod loops {
22922292
}
22932293
}
22942294

2295-
mod dereference;
2296-
22972295
mod explicit_type_args {
22982296
struct S1<T>(T);
22992297

@@ -2461,6 +2459,9 @@ mod closures {
24612459
}
24622460
}
24632461

2462+
mod dereference;
2463+
mod dyn_type;
2464+
24642465
fn main() {
24652466
field_access::f(); // $ target=f
24662467
method_impl::f(); // $ target=f
@@ -2491,5 +2492,6 @@ fn main() {
24912492
dereference::test(); // $ target=test
24922493
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
24932494
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
2494-
closures::f() // $ target=f
2495+
closures::f(); // $ target=f
2496+
dyn_type::test(); // $ target=test
24952497
}

0 commit comments

Comments
 (0)