Skip to content

Commit f432cf9

Browse files
authored
Merge pull request #20041 from paldepind/rust/type-inference-tuples
Rust: Type inference for tuples
2 parents 09dd708 + 6b366d8 commit f432cf9

File tree

12 files changed

+737
-132
lines changed

12 files changed

+737
-132
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/lib/codeql/rust/elements/internal/TuplePatImpl.qll

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `TuplePat`.
43
*
@@ -12,12 +11,25 @@ private import codeql.rust.elements.internal.generated.TuplePat
1211
* be referenced directly.
1312
*/
1413
module Impl {
14+
private import rust
15+
16+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1517
/**
1618
* A tuple pattern. For example:
1719
* ```rust
1820
* let (x, y) = (1, 2);
1921
* let (a, b, .., z) = (1, 2, 3, 4, 5);
2022
* ```
2123
*/
22-
class TuplePat extends Generated::TuplePat { }
24+
class TuplePat extends Generated::TuplePat {
25+
/**
26+
* Gets the arity of the tuple matched by this pattern, if any.
27+
*
28+
* This is the number of fields in the tuple pattern if and only if the
29+
* pattern does not contain a `..` pattern.
30+
*/
31+
int getTupleArity() {
32+
result = this.getNumberOfFields() and not this.getAField() instanceof RestPat
33+
}
34+
}
2335
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,23 @@ private import codeql.rust.elements.internal.generated.Synth
99

1010
cached
1111
newtype TType =
12-
TUnit() or
13-
TStruct(Struct s) { Stages::TypeInferenceStage::ref() } or
12+
TTuple(int arity) {
13+
arity =
14+
[
15+
any(TupleTypeRepr t).getNumberOfFields(),
16+
any(TupleExpr e).getNumberOfFields(),
17+
any(TuplePat p).getNumberOfFields()
18+
] and
19+
Stages::TypeInferenceStage::ref()
20+
} or
21+
TStruct(Struct s) or
1422
TEnum(Enum e) or
1523
TTrait(Trait t) or
1624
TArrayType() or // todo: add size?
1725
TRefType() or // todo: add mut?
1826
TImplTraitType(ImplTraitTypeRepr impl) or
1927
TSliceType() or
28+
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
2029
TTypeParamTypeParameter(TypeParam t) or
2130
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2231
TArrayTypeParameter() or
@@ -55,21 +64,33 @@ abstract class Type extends TType {
5564
abstract Location getLocation();
5665
}
5766

58-
/** The unit type `()`. */
59-
class UnitType extends Type, TUnit {
60-
UnitType() { this = TUnit() }
67+
/** A tuple type `(T, ...)`. */
68+
class TupleType extends Type, TTuple {
69+
private int arity;
70+
71+
TupleType() { this = TTuple(arity) }
6172

6273
override StructField getStructField(string name) { none() }
6374

6475
override TupleField getTupleField(int i) { none() }
6576

66-
override TypeParameter getTypeParameter(int i) { none() }
77+
override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(arity, i) }
6778

68-
override string toString() { result = "()" }
79+
/** Gets the arity of this tuple type. */
80+
int getArity() { result = arity }
81+
82+
override string toString() { result = "(T_" + arity + ")" }
6983

7084
override Location getLocation() { result instanceof EmptyLocation }
7185
}
7286

87+
/** The unit type `()`. */
88+
class UnitType extends TupleType, TTuple {
89+
UnitType() { this = TTuple(0) }
90+
91+
override string toString() { result = "()" }
92+
}
93+
7394
abstract private class StructOrEnumType extends Type {
7495
abstract ItemNode asItemNode();
7596
}
@@ -329,6 +350,30 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
329350
override Location getLocation() { result = typeAlias.getLocation() }
330351
}
331352

353+
/**
354+
* A tuple type parameter. For instance the `T` in `(T, U)`.
355+
*
356+
* Since tuples are structural their type parameters can be represented as their
357+
* positional index. The type inference library requires that type parameters
358+
* belong to a single type, so we also include the arity of the tuple type.
359+
*/
360+
class TupleTypeParameter extends TypeParameter, TTupleTypeParameter {
361+
private int arity;
362+
private int index;
363+
364+
TupleTypeParameter() { this = TTupleTypeParameter(arity, index) }
365+
366+
override string toString() { result = index.toString() + "(" + arity + ")" }
367+
368+
override Location getLocation() { result instanceof EmptyLocation }
369+
370+
/** Gets the index of this tuple type parameter. */
371+
int getIndex() { result = index }
372+
373+
/** Gets the tuple type that corresponds to this tuple type parameter. */
374+
TupleType getTupleType() { result = TTuple(arity) }
375+
}
376+
332377
/** An implicit array type parameter. */
333378
class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
334379
override string toString() { result = "[T;...]" }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ private module Input1 implements InputSig1<Location> {
103103
node = tp0.(SelfTypeParameter).getTrait() or
104104
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
105105
)
106+
or
107+
exists(TupleTypeParameter ttp, int maxArity |
108+
maxArity = max(int i | i = any(TupleType tt).getArity()) and
109+
tp0 = ttp and
110+
kind = 2 and
111+
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
112+
)
106113
|
107114
tp0 order by kind, id
108115
)
@@ -229,7 +236,7 @@ private Type inferLogicalOperationType(AstNode n, TypePath path) {
229236
private Type inferAssignmentOperationType(AstNode n, TypePath path) {
230237
n instanceof AssignmentOperation and
231238
path.isEmpty() and
232-
result = TUnit()
239+
result instanceof UnitType
233240
}
234241

235242
pragma[nomagic]
@@ -321,6 +328,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
321328
prefix1.isEmpty() and
322329
prefix2 = TypePath::singleton(TRefTypeParameter())
323330
or
331+
exists(int i, int arity |
332+
prefix1.isEmpty() and
333+
prefix2 = TypePath::singleton(TTupleTypeParameter(arity, i))
334+
|
335+
arity = n2.(TupleExpr).getNumberOfFields() and
336+
n1 = n2.(TupleExpr).getField(i)
337+
or
338+
arity = n2.(TuplePat).getTupleArity() and
339+
n1 = n2.(TuplePat).getField(i)
340+
)
341+
or
324342
exists(BlockExpr be |
325343
n1 = be and
326344
n2 = be.getStmtList().getTailExpr() and
@@ -534,6 +552,12 @@ private Type inferStructExprType(AstNode n, TypePath path) {
534552
)
535553
}
536554

555+
pragma[nomagic]
556+
private Type inferTupleRootType(AstNode n) {
557+
// `typeEquality` handles the non-root cases
558+
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
559+
}
560+
537561
pragma[nomagic]
538562
private Type inferPathExprType(PathExpr pe, TypePath path) {
539563
// nullary struct/variant constructors
@@ -1055,6 +1079,42 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
10551079
)
10561080
}
10571081

1082+
pragma[nomagic]
1083+
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
1084+
exists(int i, TypePath path0 |
1085+
fe.getIdentifier().getText() = i.toString() and
1086+
result = inferType(fe.getContainer(), path0) and
1087+
path0.isCons(TTupleTypeParameter(_, i), path) and
1088+
fe.getIdentifier().getText() = i.toString()
1089+
)
1090+
}
1091+
1092+
/** Infers the type of `t` in `t.n` when `t` is a tuple. */
1093+
private Type inferTupleContainerExprType(Expr e, TypePath path) {
1094+
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
1095+
// tuple as in:
1096+
// ```rust
1097+
// let t = (Default::default(), 2);
1098+
// let s: String = t.0;
1099+
// ```
1100+
// But it could also be a tuple struct as in:
1101+
// ```rust
1102+
// struct T(String, u32);
1103+
// let t = T(Default::default(), 2);
1104+
// let s: String = t.0;
1105+
// ```
1106+
// We need type information to flow from `t.n` to tuple type parameters of `t`
1107+
// in the former case but not the latter case. Hence we include the condition
1108+
// that the root type of `t` must be a tuple type.
1109+
exists(int i, TypePath path0, FieldExpr fe, int arity |
1110+
e = fe.getContainer() and
1111+
fe.getIdentifier().getText() = i.toString() and
1112+
arity = inferType(fe.getContainer()).(TupleType).getArity() and
1113+
result = inferType(fe, path0) and
1114+
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
1115+
)
1116+
}
1117+
10581118
/** Gets the root type of the reference node `ref`. */
10591119
pragma[nomagic]
10601120
private Type inferRefNodeType(AstNode ref) {
@@ -1943,12 +2003,19 @@ private module Cached {
19432003
or
19442004
result = inferStructExprType(n, path)
19452005
or
2006+
result = inferTupleRootType(n) and
2007+
path.isEmpty()
2008+
or
19462009
result = inferPathExprType(n, path)
19472010
or
19482011
result = inferCallExprBaseType(n, path)
19492012
or
19502013
result = inferFieldExprType(n, path)
19512014
or
2015+
result = inferTupleIndexExprType(n, path)
2016+
or
2017+
result = inferTupleContainerExprType(n, path)
2018+
or
19522019
result = inferRefNodeType(n) and
19532020
path.isEmpty()
19542021
or

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ abstract class TypeMention extends AstNode {
1414
final Type resolveType() { result = this.resolveTypeAt(TypePath::nil()) }
1515
}
1616

17+
class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
18+
override Type resolveTypeAt(TypePath path) {
19+
path.isEmpty() and
20+
result = TTuple(super.getNumberOfFields())
21+
or
22+
exists(TypePath suffix, int i |
23+
result = super.getField(i).(TypeMention).resolveTypeAt(suffix) and
24+
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfFields(), i), suffix)
25+
)
26+
}
27+
}
28+
1729
class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
1830
override Type resolveTypeAt(TypePath path) {
1931
path.isEmpty() and
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Type inference now supports tuple types.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| main.rs:445:18:445:24 | n.len() |

rust/ql/test/library-tests/dataflow/local/DataFlowStep.expected

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,7 @@ readStep
979979
| main.rs:442:25:442:29 | names | file://:0:0:0:0 | element | main.rs:442:9:442:20 | TuplePat |
980980
| main.rs:444:41:444:67 | [post] \|...\| ... | main.rs:441:9:441:20 | captured default_name | main.rs:444:41:444:67 | [post] default_name |
981981
| main.rs:444:44:444:55 | this | main.rs:441:9:441:20 | captured default_name | main.rs:444:44:444:55 | default_name |
982+
| main.rs:445:18:445:18 | [post] receiver for n | file://:0:0:0:0 | &ref | main.rs:445:18:445:18 | [post] n |
982983
| main.rs:469:13:469:13 | [post] receiver for b | file://:0:0:0:0 | &ref | main.rs:469:13:469:13 | [post] b |
983984
| main.rs:470:18:470:18 | [post] receiver for b | file://:0:0:0:0 | &ref | main.rs:470:18:470:18 | [post] b |
984985
| main.rs:481:10:481:11 | vs | file://:0:0:0:0 | element | main.rs:481:10:481:14 | vs[0] |
@@ -1078,6 +1079,7 @@ storeStep
10781079
| main.rs:429:30:429:30 | 3 | file://:0:0:0:0 | element | main.rs:429:23:429:31 | [...] |
10791080
| main.rs:432:18:432:27 | source(...) | file://:0:0:0:0 | element | main.rs:432:5:432:11 | [post] mut_arr |
10801081
| main.rs:444:41:444:67 | default_name | main.rs:441:9:441:20 | captured default_name | main.rs:444:41:444:67 | \|...\| ... |
1082+
| main.rs:445:18:445:18 | n | file://:0:0:0:0 | &ref | main.rs:445:18:445:18 | receiver for n |
10811083
| main.rs:469:13:469:13 | b | file://:0:0:0:0 | &ref | main.rs:469:13:469:13 | receiver for b |
10821084
| main.rs:470:18:470:18 | b | file://:0:0:0:0 | &ref | main.rs:470:18:470:18 | receiver for b |
10831085
| main.rs:479:15:479:24 | source(...) | file://:0:0:0:0 | element | main.rs:479:14:479:34 | [...] |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,20 +2361,36 @@ mod tuples {
23612361
}
23622362

23632363
pub fn f() {
2364-
let a = S1::get_pair(); // $ target=get_pair MISSING: type=a:?
2365-
let mut b = S1::get_pair(); // $ target=get_pair MISSING: type=b:?
2366-
let (c, d) = S1::get_pair(); // $ target=get_pair MISSING: type=c:? type=d:?
2367-
let (mut e, f) = S1::get_pair(); // $ target=get_pair MISSING: type=e: type=f:
2368-
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair MISSING: type=g:? type=h:?
2369-
2370-
a.0.foo(); // $ MISSING: target=foo
2371-
b.1.foo(); // $ MISSING: target=foo
2372-
c.foo(); // $ MISSING: target=foo
2373-
d.foo(); // $ MISSING: target=foo
2374-
e.foo(); // $ MISSING: target=foo
2375-
f.foo(); // $ MISSING: target=foo
2376-
g.foo(); // $ MISSING: target=foo
2377-
h.foo(); // $ MISSING: target=foo
2364+
let a = S1::get_pair(); // $ target=get_pair type=a:(T_2)
2365+
let mut b = S1::get_pair(); // $ target=get_pair type=b:(T_2)
2366+
let (c, d) = S1::get_pair(); // $ target=get_pair type=c:S1 type=d:S1
2367+
let (mut e, f) = S1::get_pair(); // $ target=get_pair type=e:S1 type=f:S1
2368+
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair type=g:S1 type=h:S1
2369+
2370+
a.0.foo(); // $ target=foo
2371+
b.1.foo(); // $ target=foo
2372+
c.foo(); // $ target=foo
2373+
d.foo(); // $ target=foo
2374+
e.foo(); // $ target=foo
2375+
f.foo(); // $ target=foo
2376+
g.foo(); // $ target=foo
2377+
h.foo(); // $ target=foo
2378+
2379+
// Here type information must flow from `pair.0` and `pair.1` into
2380+
// `pair` and from `(a, b)` into `a` and `b` in order for the types of
2381+
// `a` and `b` to be inferred.
2382+
let a = Default::default(); // $ target=default type=a:i64
2383+
let b = Default::default(); // $ target=default type=b:bool
2384+
let pair = (a, b); // $ type=pair:0(2).i64 type=pair:1(2).bool
2385+
let i: i64 = pair.0;
2386+
let j: bool = pair.1;
2387+
2388+
let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
2389+
match pair {
2390+
(0,0) => print!("unexpected"),
2391+
_ => print!("expected"),
2392+
}
2393+
let x = pair.0; // $ type=x:i32
23782394
}
23792395
}
23802396

0 commit comments

Comments
 (0)