Skip to content

Rust: Type inference for impl trait types with type parameters #20119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rust/ql/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion rust/ql/.gitattributes

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// generated by codegen, remove this comment if you wish to edit this file
/**
* This module provides a hand-modifiable wrapper around the generated class `ImplTraitTypeRepr`.
*
* INTERNAL: Do not use.
*/

private import codeql.rust.elements.internal.generated.ImplTraitTypeRepr
private import rust

/**
* INTERNAL: This module contains the customizable definition of `ImplTraitTypeRepr` and should not
* be referenced directly.
*/
module Impl {
// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* An `impl Trait` type.
*
Expand All @@ -21,5 +22,11 @@ module Impl {
* // ^^^^^^^^^^^^^^^^^^^^^^^^^^
* ```
*/
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr { }
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr {
/**
* Gets the function for which this impl trait type occurs in the return
* type, if any.
*/
Function getFunctionReturnPos() { this.getParentNode*() = result.getRetType().getTypeRepr() }
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like us to come up with a better name for this predicate - in particular the pattern get...Pos makes my brain expect this to return a position / index, which it doesn't. On the other hand "return position" is established phrasing so we don't want to lose that. Perhaps just re-order it to getReturnPosFunction???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about getReturnPosOfFunction? We could also split this into two predicates getFunction and isReturnPos. Both of those seem clear, but we don't otherwise have a need for the general getFunction that also works for impl in arguments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a readability point of view I quite like the two predicates solution. Assuming it doesn't make the code much more messy elsewhere.

}
38 changes: 35 additions & 3 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,20 @@ newtype TType =
TDynTraitTypeParameter(TypeParam tp) {
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
} or
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
implTraitTypeParam(implTrait, _, tp)
} or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()

predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypeParam tp) {
tp = implTrait.getFunctionReturnPos().getGenericParamList().getTypeParam(i) and
// Only include type parameters of the function that occur inside the impl
// trait type.
exists(Path path | path.getParentNode*() = implTrait and resolvePath(path) = tp)
}

/**
* A type without type arguments.
*
Expand Down Expand Up @@ -244,7 +254,12 @@ class ImplTraitType extends Type, TImplTraitType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }
override TypeParameter getTypeParameter(int i) {
exists(TypeParam tp |
implTraitTypeParam(impl, i, tp) and
result = TImplTraitTypeParameter(impl, tp)
)
}

override string toString() { result = impl.toString() }

Expand Down Expand Up @@ -283,7 +298,7 @@ class DynTraitType extends Type, TDynTraitType {
class ImplTraitReturnType extends ImplTraitType {
private Function function;

ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
ImplTraitReturnType() { function = impl.getFunctionReturnPos() }

override Function getFunction() { result = function }
}
Expand Down Expand Up @@ -417,6 +432,21 @@ class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
override Location getLocation() { result = typeParam.getLocation() }
}

class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
private TypeParam typeParam;
private ImplTraitTypeRepr implTrait;

ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTrait, typeParam) }

TypeParam getTypeParam() { result = typeParam }

ImplTraitTypeRepr getImplTraitTypeRepr() { result = implTrait }

override string toString() { result = "impl(" + typeParam.toString() + ")" }

override Location getLocation() { result = typeParam.getLocation() }
}

/** An implicit reference type parameter. */
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
override string toString() { result = "&T" }
Expand Down Expand Up @@ -531,5 +561,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
}

final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
override TypeParameter getATypeParameter() { none() }
override TypeParameter getATypeParameter() {
implTraitTypeParam(this, _, result.(TypeParamTypeParameter).getTypeParam())
}
}
32 changes: 19 additions & 13 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,44 @@ private module Input1 implements InputSig1<Location> {

int getTypeParameterId(TypeParameter tp) {
tp =
rank[result](TypeParameter tp0, int kind, int id |
rank[result](TypeParameter tp0, int kind, int id1, int id2 |
tp0 instanceof ArrayTypeParameter and
kind = 0 and
id = 0
id1 = 0 and
id2 = 0
or
tp0 instanceof RefTypeParameter and
kind = 0 and
id = 1
id1 = 0 and
id2 = 1
or
tp0 instanceof SliceTypeParameter and
kind = 0 and
id = 2
id1 = 0 and
id2 = 2
or
kind = 1 and
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
id1 = 0 and
id2 = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
id1 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getImplTraitTypeRepr()) and
id2 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getTypeParam())
or
kind = 3 and
id1 = 0 and
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
node = tp0.(SelfTypeParameter).getTrait() or
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
)
or
exists(TupleTypeParameter ttp, int maxArity |
maxArity = max(int i | i = any(TupleType tt).getArity()) and
tp0 = ttp and
kind = 3 and
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
)
kind = 4 and
id1 = tp0.(TupleTypeParameter).getTupleType().getArity() and
id2 = tp0.(TupleTypeParameter).getIndex()
|
tp0 order by kind, id
tp0 order by kind, id1, id2
)
}
}
Expand Down
6 changes: 6 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr
override Type resolveTypeAt(TypePath typePath) {
typePath.isEmpty() and
result.(ImplTraitType).getImplTraitTypeRepr() = this
or
exists(ImplTraitTypeParameter tp |
this = tp.getImplTraitTypeRepr() and
typePath = TypePath::singleton(tp) and
result = TTypeParamTypeParameter(tp.getTypeParam())
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:2213:13:2213:31 | ...::from(...) |
| main.rs:2214:13:2214:31 | ...::from(...) |
| main.rs:2215:13:2215:31 | ...::from(...) |
| main.rs:2221:13:2221:31 | ...::from(...) |
| main.rs:2222:13:2222:31 | ...::from(...) |
| main.rs:2223:13:2223:31 | ...::from(...) |
| main.rs:2238:13:2238:31 | ...::from(...) |
| main.rs:2239:13:2239:31 | ...::from(...) |
| main.rs:2240:13:2240:31 | ...::from(...) |
| main.rs:2246:13:2246:31 | ...::from(...) |
| main.rs:2247:13:2247:31 | ...::from(...) |
| main.rs:2248:13:2248:31 | ...::from(...) |
27 changes: 26 additions & 1 deletion rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1873,8 +1873,10 @@ mod async_ {
}

mod impl_trait {
#[derive(Copy, Clone)]
struct S1;
struct S2;
struct S3<T3>(T3);

trait Trait1 {
fn f1(&self) {} // Trait1f1
Expand Down Expand Up @@ -1906,6 +1908,13 @@ mod impl_trait {
}
}

impl<T: Clone> MyTrait<T> for S3<T> {
fn get_a(&self) -> T {
let S3(t) = self;
t.clone()
}
}

fn get_a_my_trait() -> impl MyTrait<S2> {
S1
}
Expand All @@ -1914,6 +1923,18 @@ mod impl_trait {
t.get_a() // $ target=MyTrait::get_a
}

fn get_a_my_trait2<T: Clone>(x: T) -> impl MyTrait<T> {
S3(x)
}

fn get_a_my_trait3<T: Clone>(x: T) -> Option<impl MyTrait<T>> {
Some(S3(x))
}

fn get_a_my_trait4<T: Clone>(x: T) -> (impl MyTrait<T>, impl MyTrait<T>) {
(S3(x.clone()), S3(x)) // $ target=clone
}

fn uses_my_trait2<A>(t: impl MyTrait<A>) -> A {
t.get_a() // $ target=MyTrait::get_a
}
Expand All @@ -1927,6 +1948,10 @@ mod impl_trait {
let a = get_a_my_trait(); // $ target=get_a_my_trait
let c = uses_my_trait2(a); // $ type=c:S2 target=uses_my_trait2
let d = uses_my_trait2(S1); // $ type=d:S2 target=uses_my_trait2
let e = get_a_my_trait2(S1).get_a(); // $ target=get_a_my_trait2 target=MyTrait::get_a type=e:S1
// For this function the `impl` type does not appear in the root of the return type
let f = get_a_my_trait3(S1).unwrap().get_a(); // $ target=get_a_my_trait3 target=unwrap target=MyTrait::get_a type=f:S1
let g = get_a_my_trait4(S1).0.get_a(); // $ target=get_a_my_trait4 target=MyTrait::get_a type=g:S1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in what happens when we have nested impl types, for example I've seen:

impl Iterator<Item = (impl Into<String>, Resource)>

I haven't tested anything but it feels like there's a danger the type parameter String could become associated with both impl types on this line due to the way implTraitTypeParam is written? Is it even wrong if it does???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question that I haven't thought of.

With how it works right now they would become type parameters of both the impl types. Whether that's right or wrong I'm not sure about. My guess is that it's more correct to only create a type parameter of the inner impl. I can try and write some tests for that in a follow up PR and take a look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow-up PR is fine, if you're keeping track. :)

}
}

Expand Down Expand Up @@ -2385,7 +2410,7 @@ mod tuples {

let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
match pair {
(0,0) => print!("unexpected"),
(0, 0) => print!("unexpected"),
_ => print!("expected"),
}
let x = pair.0; // $ type=x:i32
Expand Down
Loading
Loading