Skip to content

Commit 8498eaa

Browse files
committed
Implement type checking for legacy_const_generics
1 parent 2ef541b commit 8498eaa

File tree

2 files changed

+87
-7
lines changed

2 files changed

+87
-7
lines changed

crates/hir_ty/src/infer/expr.rs

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::{
2828
lower::{
2929
const_or_path_to_chalk, generic_arg_to_chalk, lower_to_chalk_mutability, ParamLoweringMode,
3030
},
31-
mapping::from_chalk,
31+
mapping::{from_chalk, ToChalk},
3232
method_resolution,
3333
primitive::{self, UintTy},
3434
static_lifetime, to_chalk_trait_id,
@@ -279,21 +279,24 @@ impl<'a> InferenceContext<'a> {
279279
let callee_ty = self.infer_expr(*callee, &Expectation::none());
280280
let mut derefs = Autoderef::new(&mut self.table, callee_ty.clone());
281281
let mut res = None;
282+
let mut derefed_callee = callee_ty.clone();
282283
// manual loop to be able to access `derefs.table`
283284
while let Some((callee_deref_ty, _)) = derefs.next() {
284285
res = derefs.table.callable_sig(&callee_deref_ty, args.len());
285286
if res.is_some() {
287+
derefed_callee = callee_deref_ty;
286288
break;
287289
}
288290
}
289-
let (param_tys, ret_ty): (Vec<Ty>, Ty) = match res {
291+
let (param_tys, ret_ty) = match res {
290292
Some(res) => {
291293
let adjustments = auto_deref_adjust_steps(&derefs);
292294
self.write_expr_adj(*callee, adjustments);
293295
res
294296
}
295297
None => (Vec::new(), self.err_ty()),
296298
};
299+
let indices_to_skip = self.check_legacy_const_generics(derefed_callee, args);
297300
self.register_obligations_for_call(&callee_ty);
298301

299302
let expected_inputs = self.expected_inputs_for_expected_output(
@@ -302,7 +305,7 @@ impl<'a> InferenceContext<'a> {
302305
param_tys.clone(),
303306
);
304307

305-
self.check_call_arguments(args, &expected_inputs, &param_tys);
308+
self.check_call_arguments(args, &expected_inputs, &param_tys, &indices_to_skip);
306309
self.normalize_associated_types_in(ret_ty)
307310
}
308311
Expr::MethodCall { receiver, args, method_name, generic_args } => self
@@ -952,7 +955,7 @@ impl<'a> InferenceContext<'a> {
952955
let expected_inputs =
953956
self.expected_inputs_for_expected_output(expected, ret_ty.clone(), param_tys.clone());
954957

955-
self.check_call_arguments(args, &expected_inputs, &param_tys);
958+
self.check_call_arguments(args, &expected_inputs, &param_tys, &[]);
956959
self.normalize_associated_types_in(ret_ty)
957960
}
958961

@@ -983,24 +986,40 @@ impl<'a> InferenceContext<'a> {
983986
}
984987
}
985988

986-
fn check_call_arguments(&mut self, args: &[ExprId], expected_inputs: &[Ty], param_tys: &[Ty]) {
989+
fn check_call_arguments(
990+
&mut self,
991+
args: &[ExprId],
992+
expected_inputs: &[Ty],
993+
param_tys: &[Ty],
994+
skip_indices: &[u32],
995+
) {
987996
// Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
988997
// We do this in a pretty awful way: first we type-check any arguments
989998
// that are not closures, then we type-check the closures. This is so
990999
// that we have more information about the types of arguments when we
9911000
// type-check the functions. This isn't really the right way to do this.
9921001
for &check_closures in &[false, true] {
1002+
let mut skip_indices = skip_indices.into_iter().copied().fuse().peekable();
9931003
let param_iter = param_tys.iter().cloned().chain(repeat(self.err_ty()));
9941004
let expected_iter = expected_inputs
9951005
.iter()
9961006
.cloned()
9971007
.chain(param_iter.clone().skip(expected_inputs.len()));
998-
for ((&arg, param_ty), expected_ty) in args.iter().zip(param_iter).zip(expected_iter) {
1008+
for (idx, ((&arg, param_ty), expected_ty)) in
1009+
args.iter().zip(param_iter).zip(expected_iter).enumerate()
1010+
{
9991011
let is_closure = matches!(&self.body[arg], Expr::Lambda { .. });
10001012
if is_closure != check_closures {
10011013
continue;
10021014
}
10031015

1016+
while skip_indices.peek().map_or(false, |i| *i < idx as u32) {
1017+
skip_indices.next();
1018+
}
1019+
if skip_indices.peek().copied() == Some(idx as u32) {
1020+
continue;
1021+
}
1022+
10041023
// the difference between param_ty and expected here is that
10051024
// expected is the parameter when the expected *return* type is
10061025
// taken into account. So in `let _: &[i32] = identity(&[1, 2])`
@@ -1140,6 +1159,49 @@ impl<'a> InferenceContext<'a> {
11401159
}
11411160
}
11421161

1162+
/// Returns the argument indices to skip.
1163+
fn check_legacy_const_generics(&mut self, callee: Ty, args: &[ExprId]) -> Vec<u32> {
1164+
let (func, subst) = match callee.kind(Interner) {
1165+
TyKind::FnDef(fn_id, subst) => {
1166+
let callable = CallableDefId::from_chalk(self.db, *fn_id);
1167+
let func = match callable {
1168+
CallableDefId::FunctionId(f) => f,
1169+
_ => return Vec::new(),
1170+
};
1171+
(func, subst)
1172+
}
1173+
_ => return Vec::new(),
1174+
};
1175+
1176+
let data = self.db.function_data(func);
1177+
if data.legacy_const_generics_indices.is_empty() {
1178+
return Vec::new();
1179+
}
1180+
1181+
// only use legacy const generics if the param count matches with them
1182+
if data.params.len() + data.legacy_const_generics_indices.len() != args.len() {
1183+
return Vec::new();
1184+
}
1185+
1186+
// check legacy const parameters
1187+
for (subst_idx, arg_idx) in data.legacy_const_generics_indices.iter().copied().enumerate() {
1188+
let arg = match subst.at(Interner, subst_idx).constant(Interner) {
1189+
Some(c) => c,
1190+
None => continue, // not a const parameter?
1191+
};
1192+
if arg_idx >= args.len() as u32 {
1193+
continue;
1194+
}
1195+
let _ty = arg.data(Interner).ty.clone();
1196+
let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly
1197+
self.infer_expr(args[arg_idx as usize], &expected);
1198+
// FIXME: evaluate and unify with the const
1199+
}
1200+
let mut indices = data.legacy_const_generics_indices.clone();
1201+
indices.sort();
1202+
indices
1203+
}
1204+
11431205
fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> {
11441206
let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
11451207
let rhs_ty = self.resolve_ty_shallow(&rhs_ty);

crates/hir_ty/src/tests/simple.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use expect_test::expect;
22

3-
use super::{check_infer, check_types};
3+
use super::{check_infer, check_no_mismatches, check_types};
44

55
#[test]
66
fn infer_box() {
@@ -2624,3 +2624,21 @@ pub mod prelude {
26242624
"#,
26252625
);
26262626
}
2627+
2628+
#[test]
2629+
fn legacy_const_generics() {
2630+
check_no_mismatches(
2631+
r#"
2632+
#[rustc_legacy_const_generics(1, 3)]
2633+
fn mixed<const N1: &'static str, const N2: bool>(
2634+
a: u8,
2635+
b: i8,
2636+
) {}
2637+
2638+
fn f() {
2639+
mixed(0, "", -1, true);
2640+
mixed::<"", true>(0, -1);
2641+
}
2642+
"#,
2643+
);
2644+
}

0 commit comments

Comments
 (0)