Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 4ea29d6

Browse files
committed
Implement parameter variance inference
1 parent 873cf25 commit 4ea29d6

File tree

8 files changed

+1271
-37
lines changed

8 files changed

+1271
-37
lines changed

src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -950,22 +950,33 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar
950950

951951
pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances {
952952
let callable_def: CallableDefId = from_chalk(db, fn_def_id);
953-
let generic_params =
954-
generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def));
955953
Variances::from_iter(
956954
Interner,
957-
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
955+
db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def))
956+
.as_deref()
957+
.unwrap_or_default()
958+
.iter()
959+
.map(|v| match v {
960+
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
961+
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
962+
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
963+
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
964+
}),
958965
)
959966
}
960967

961968
pub(crate) fn adt_variance_query(
962969
db: &dyn HirDatabase,
963970
chalk_ir::AdtId(adt_id): AdtId,
964971
) -> Variances {
965-
let generic_params = generics(db.upcast(), adt_id.into());
966972
Variances::from_iter(
967973
Interner,
968-
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
974+
db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v {
975+
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
976+
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
977+
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
978+
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
979+
}),
969980
)
970981
}
971982

src/tools/rust-analyzer/crates/hir-ty/src/db.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
271271
#[ra_salsa::invoke(chalk_db::adt_variance_query)]
272272
fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances;
273273

274+
#[ra_salsa::invoke(crate::variance::variances_of)]
275+
#[ra_salsa::cycle(crate::variance::variances_of_cycle)]
276+
fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>;
277+
274278
#[ra_salsa::invoke(chalk_db::associated_ty_value_query)]
275279
fn associated_ty_value(
276280
&self,

src/tools/rust-analyzer/crates/hir-ty/src/generics.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,14 @@ impl Generics {
132132
self.params.len()
133133
}
134134

135+
pub(crate) fn len_self_lifetimes(&self) -> usize {
136+
self.params.len_lifetimes()
137+
}
138+
139+
pub(crate) fn has_trait_self(&self) -> bool {
140+
self.params.trait_self_param().is_some()
141+
}
142+
135143
/// (parent total, self param, type params, const params, impl trait list, lifetimes)
136144
pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) {
137145
let mut self_param = false;

src/tools/rust-analyzer/crates/hir-ty/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub mod traits;
5050
mod test_db;
5151
#[cfg(test)]
5252
mod tests;
53+
mod variance;
5354

5455
use std::hash::Hash;
5556

src/tools/rust-analyzer/crates/hir-ty/src/tests.rs

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
127127
None => continue,
128128
};
129129
let def_map = module.def_map(&db);
130-
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
130+
visit_module(&db, &def_map, module.local_id, &mut |it| {
131+
defs.push(match it {
132+
ModuleDefId::FunctionId(it) => it.into(),
133+
ModuleDefId::EnumVariantId(it) => it.into(),
134+
ModuleDefId::ConstId(it) => it.into(),
135+
ModuleDefId::StaticId(it) => it.into(),
136+
_ => return,
137+
})
138+
});
131139
}
132140
defs.sort_by_key(|def| match def {
133141
DefWithBodyId::FunctionId(it) => {
@@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
375383
let def_map = module.def_map(&db);
376384

377385
let mut defs: Vec<DefWithBodyId> = Vec::new();
378-
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
386+
visit_module(&db, &def_map, module.local_id, &mut |it| {
387+
defs.push(match it {
388+
ModuleDefId::FunctionId(it) => it.into(),
389+
ModuleDefId::EnumVariantId(it) => it.into(),
390+
ModuleDefId::ConstId(it) => it.into(),
391+
ModuleDefId::StaticId(it) => it.into(),
392+
_ => return,
393+
})
394+
});
379395
defs.sort_by_key(|def| match def {
380396
DefWithBodyId::FunctionId(it) => {
381397
let loc = it.lookup(&db);
@@ -405,30 +421,30 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
405421
buf
406422
}
407423

408-
fn visit_module(
424+
pub(crate) fn visit_module(
409425
db: &TestDB,
410426
crate_def_map: &DefMap,
411427
module_id: LocalModuleId,
412-
cb: &mut dyn FnMut(DefWithBodyId),
428+
cb: &mut dyn FnMut(ModuleDefId),
413429
) {
414430
visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
415431
for impl_id in crate_def_map[module_id].scope.impls() {
416432
let impl_data = db.impl_data(impl_id);
417433
for &item in impl_data.items.iter() {
418434
match item {
419435
AssocItemId::FunctionId(it) => {
420-
let def = it.into();
421-
cb(def);
422-
let body = db.body(def);
436+
let body = db.body(it.into());
437+
cb(it.into());
423438
visit_body(db, &body, cb);
424439
}
425440
AssocItemId::ConstId(it) => {
426-
let def = it.into();
427-
cb(def);
428-
let body = db.body(def);
441+
let body = db.body(it.into());
442+
cb(it.into());
429443
visit_body(db, &body, cb);
430444
}
431-
AssocItemId::TypeAliasId(_) => (),
445+
AssocItemId::TypeAliasId(it) => {
446+
cb(it.into());
447+
}
432448
}
433449
}
434450
}
@@ -437,33 +453,27 @@ fn visit_module(
437453
db: &TestDB,
438454
crate_def_map: &DefMap,
439455
scope: &ItemScope,
440-
cb: &mut dyn FnMut(DefWithBodyId),
456+
cb: &mut dyn FnMut(ModuleDefId),
441457
) {
442458
for decl in scope.declarations() {
459+
cb(decl);
443460
match decl {
444461
ModuleDefId::FunctionId(it) => {
445-
let def = it.into();
446-
cb(def);
447-
let body = db.body(def);
462+
let body = db.body(it.into());
448463
visit_body(db, &body, cb);
449464
}
450465
ModuleDefId::ConstId(it) => {
451-
let def = it.into();
452-
cb(def);
453-
let body = db.body(def);
466+
let body = db.body(it.into());
454467
visit_body(db, &body, cb);
455468
}
456469
ModuleDefId::StaticId(it) => {
457-
let def = it.into();
458-
cb(def);
459-
let body = db.body(def);
470+
let body = db.body(it.into());
460471
visit_body(db, &body, cb);
461472
}
462473
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
463474
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
464-
let def = it.into();
465-
cb(def);
466-
let body = db.body(def);
475+
let body = db.body(it.into());
476+
cb(it.into());
467477
visit_body(db, &body, cb);
468478
});
469479
}
@@ -473,7 +483,7 @@ fn visit_module(
473483
match item {
474484
AssocItemId::FunctionId(it) => cb(it.into()),
475485
AssocItemId::ConstId(it) => cb(it.into()),
476-
AssocItemId::TypeAliasId(_) => (),
486+
AssocItemId::TypeAliasId(it) => cb(it.into()),
477487
}
478488
}
479489
}
@@ -483,7 +493,7 @@ fn visit_module(
483493
}
484494
}
485495

486-
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
496+
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) {
487497
for (_, def_map) in body.blocks(db) {
488498
for (mod_id, _) in def_map.modules() {
489499
visit_module(db, &def_map, mod_id, cb);
@@ -553,7 +563,13 @@ fn salsa_bug() {
553563
let module = db.module_for_file(pos.file_id);
554564
let crate_def_map = module.def_map(&db);
555565
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
556-
db.infer(def);
566+
db.infer(match def {
567+
ModuleDefId::FunctionId(it) => it.into(),
568+
ModuleDefId::EnumVariantId(it) => it.into(),
569+
ModuleDefId::ConstId(it) => it.into(),
570+
ModuleDefId::StaticId(it) => it.into(),
571+
_ => return,
572+
});
557573
});
558574

559575
let new_text = "
@@ -586,6 +602,12 @@ fn salsa_bug() {
586602
let module = db.module_for_file(pos.file_id);
587603
let crate_def_map = module.def_map(&db);
588604
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
589-
db.infer(def);
605+
db.infer(match def {
606+
ModuleDefId::FunctionId(it) => it.into(),
607+
ModuleDefId::EnumVariantId(it) => it.into(),
608+
ModuleDefId::ConstId(it) => it.into(),
609+
ModuleDefId::StaticId(it) => it.into(),
610+
_ => return,
611+
});
590612
});
591613
}

src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) {
2424

2525
let mut captures_info = Vec::new();
2626
for def in defs {
27+
let def = match def {
28+
hir_def::ModuleDefId::FunctionId(it) => it.into(),
29+
hir_def::ModuleDefId::EnumVariantId(it) => it.into(),
30+
hir_def::ModuleDefId::ConstId(it) => it.into(),
31+
hir_def::ModuleDefId::StaticId(it) => it.into(),
32+
_ => continue,
33+
};
2734
let infer = db.infer(def);
2835
let db = &db;
2936
captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| {

src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use base_db::SourceDatabaseFileInputExt as _;
2+
use hir_def::ModuleDefId;
23
use test_fixture::WithFixture;
34

45
use crate::{db::HirDatabase, test_db::TestDB};
@@ -19,7 +20,9 @@ fn foo() -> i32 {
1920
let module = db.module_for_file(pos.file_id.file_id());
2021
let crate_def_map = module.def_map(&db);
2122
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
22-
db.infer(def);
23+
if let ModuleDefId::FunctionId(it) = def {
24+
db.infer(it.into());
25+
}
2326
});
2427
});
2528
assert!(format!("{events:?}").contains("infer"))
@@ -39,7 +42,9 @@ fn foo() -> i32 {
3942
let module = db.module_for_file(pos.file_id.file_id());
4043
let crate_def_map = module.def_map(&db);
4144
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
42-
db.infer(def);
45+
if let ModuleDefId::FunctionId(it) = def {
46+
db.infer(it.into());
47+
}
4348
});
4449
});
4550
assert!(!format!("{events:?}").contains("infer"), "{events:#?}")
@@ -66,7 +71,9 @@ fn baz() -> i32 {
6671
let module = db.module_for_file(pos.file_id.file_id());
6772
let crate_def_map = module.def_map(&db);
6873
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
69-
db.infer(def);
74+
if let ModuleDefId::FunctionId(it) = def {
75+
db.infer(it.into());
76+
}
7077
});
7178
});
7279
assert!(format!("{events:?}").contains("infer"))
@@ -91,7 +98,9 @@ fn baz() -> i32 {
9198
let module = db.module_for_file(pos.file_id.file_id());
9299
let crate_def_map = module.def_map(&db);
93100
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
94-
db.infer(def);
101+
if let ModuleDefId::FunctionId(it) = def {
102+
db.infer(it.into());
103+
}
95104
});
96105
});
97106
assert!(format!("{events:?}").matches("infer").count() == 1, "{events:#?}")

0 commit comments

Comments
 (0)