Skip to content

Commit 3f94a90

Browse files
committed
Infer FnSig from Fn traits
1 parent 6654055 commit 3f94a90

File tree

5 files changed

+197
-18
lines changed

5 files changed

+197
-18
lines changed

crates/ra_hir_ty/src/infer/coerce.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ impl<'a> InferenceContext<'a> {
3838
// Special case: two function types. Try to coerce both to
3939
// pointers to have a chance at getting a match. See
4040
// https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916
41-
let sig1 = ty1.callable_sig(self.db).expect("FnDef without callable sig");
42-
let sig2 = ty2.callable_sig(self.db).expect("FnDef without callable sig");
41+
let sig1 = self.callable_sig(ty1).expect("FnDef without callable sig");
42+
let sig2 = self.callable_sig(ty2).expect("FnDef without callable sig");
4343
let ptr_ty1 = Ty::fn_ptr(sig1);
4444
let ptr_ty2 = Ty::fn_ptr(sig2);
4545
self.coerce_merge_branch(&ptr_ty1, &ptr_ty2)
@@ -93,7 +93,7 @@ impl<'a> InferenceContext<'a> {
9393

9494
// `{function_type}` -> `fn()`
9595
(ty_app!(TypeCtor::FnDef(_)), ty_app!(TypeCtor::FnPtr { .. })) => {
96-
match from_ty.callable_sig(self.db) {
96+
match self.callable_sig(&from_ty) {
9797
None => return false,
9898
Some(sig) => {
9999
from_ty = Ty::fn_ptr(sig);

crates/ra_hir_ty/src/infer/expr.rs

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ use ra_syntax::ast::RangeOp;
1515

1616
use crate::{
1717
autoderef, method_resolution, op,
18-
traits::InEnvironment,
18+
traits::{builtin::get_fn_trait, FnTrait, InEnvironment, SolutionVariables},
1919
utils::{generics, variant_data, Generics},
20-
ApplicationTy, Binders, CallableDef, InferTy, IntTy, Mutability, Obligation, Rawness, Substs,
21-
TraitRef, Ty, TypeCtor,
20+
ApplicationTy, Binders, CallableDef, FnSig, InferTy, IntTy, Mutability, Obligation, Rawness,
21+
Substs, TraitRef, Ty, TypeCtor,
2222
};
2323

2424
use super::{
2525
find_breakable, BindingMode, BreakableContext, Diverges, Expectation, InferenceContext,
26-
InferenceDiagnostic, TypeMismatch,
26+
InferenceDiagnostic, Solution, TypeMismatch,
2727
};
2828

2929
impl<'a> InferenceContext<'a> {
@@ -63,6 +63,75 @@ impl<'a> InferenceContext<'a> {
6363
self.resolve_ty_as_possible(ty)
6464
}
6565

66+
fn callable_sig_from_fn_trait(&mut self, ty: &Ty) -> Option<FnSig> {
67+
if let Some(krate) = self.resolver.krate() {
68+
let fn_traits: Vec<crate::TraitId> = [FnTrait::FnOnce, FnTrait::FnMut, FnTrait::Fn]
69+
.iter()
70+
.filter_map(|f| get_fn_trait(self.db, krate, *f))
71+
.collect();
72+
for fn_trait in fn_traits {
73+
let fn_trait_data = self.db.trait_data(fn_trait);
74+
let generic_params = generics(self.db.upcast(), fn_trait.into());
75+
if generic_params.len() != 2 {
76+
continue;
77+
}
78+
79+
let arg_ty = self.table.new_type_var();
80+
let substs = Substs::build_for_generics(&generic_params)
81+
.push(ty.clone())
82+
.push(arg_ty.clone())
83+
.build();
84+
85+
let trait_ref = TraitRef { trait_: fn_trait, substs: substs.clone() };
86+
let trait_env = Arc::clone(&self.trait_env);
87+
let implements_fn_goal =
88+
self.canonicalizer().canonicalize_obligation(InEnvironment {
89+
value: Obligation::Trait(trait_ref),
90+
environment: trait_env,
91+
});
92+
if let Some(Solution::Unique(SolutionVariables(solution))) =
93+
self.db.trait_solve(krate, implements_fn_goal.value.clone())
94+
{
95+
match solution.value.as_slice() {
96+
[Ty::Apply(ApplicationTy {
97+
ctor: TypeCtor::Tuple { cardinality: _ },
98+
parameters,
99+
})] => {
100+
let output_assoc_type = match fn_trait_data
101+
.associated_types()
102+
.collect::<Vec<hir_def::TypeAliasId>>()
103+
.as_slice()
104+
{
105+
[output] => *output,
106+
_ => {
107+
continue;
108+
}
109+
};
110+
let output_proj_ty = crate::ProjectionTy {
111+
associated_ty: output_assoc_type,
112+
parameters: substs,
113+
};
114+
let return_ty = self.normalize_projection_ty(output_proj_ty);
115+
return Some(FnSig::from_params_and_return(
116+
parameters.into_iter().map(|ty| ty.clone()).collect(),
117+
return_ty,
118+
));
119+
}
120+
_ => (),
121+
}
122+
}
123+
}
124+
};
125+
None
126+
}
127+
128+
pub fn callable_sig(&mut self, ty: &Ty) -> Option<FnSig> {
129+
match ty.callable_sig(self.db) {
130+
result @ Some(_) => result,
131+
None => self.callable_sig_from_fn_trait(ty),
132+
}
133+
}
134+
66135
fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
67136
let body = Arc::clone(&self.body); // avoid borrow checker problem
68137
let ty = match &body[tgt_expr] {
@@ -198,14 +267,21 @@ impl<'a> InferenceContext<'a> {
198267
}
199268
Expr::Call { callee, args } => {
200269
let callee_ty = self.infer_expr(*callee, &Expectation::none());
201-
let (param_tys, ret_ty) = match callee_ty.callable_sig(self.db) {
202-
Some(sig) => (sig.params().to_vec(), sig.ret().clone()),
203-
None => {
204-
// Not callable
205-
// FIXME: report an error
206-
(Vec::new(), Ty::Unknown)
207-
}
208-
};
270+
let canonicalized = self.canonicalizer().canonicalize_ty(callee_ty.clone());
271+
let mut derefs = autoderef(
272+
self.db,
273+
self.resolver.krate(),
274+
InEnvironment {
275+
value: canonicalized.value.clone(),
276+
environment: self.trait_env.clone(),
277+
},
278+
);
279+
let (param_tys, ret_ty): (Vec<Ty>, Ty) = derefs
280+
.find_map(|callee_deref_ty| {
281+
self.callable_sig(&canonicalized.decanonicalize_ty(callee_deref_ty.value))
282+
.map(|sig| (sig.params().to_vec(), sig.ret().clone()))
283+
})
284+
.unwrap_or((Vec::new(), Ty::Unknown));
209285
self.register_obligations_for_call(&callee_ty);
210286
self.check_call_arguments(args, &param_tys);
211287
self.normalize_associated_types_in(ret_ty)
@@ -692,7 +768,7 @@ impl<'a> InferenceContext<'a> {
692768
let method_ty = method_ty.subst(&substs);
693769
let method_ty = self.insert_type_vars(method_ty);
694770
self.register_obligations_for_call(&method_ty);
695-
let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
771+
let (expected_receiver_ty, param_tys, ret_ty) = match self.callable_sig(&method_ty) {
696772
Some(sig) => {
697773
if !sig.params().is_empty() {
698774
(sig.params()[0].clone(), sig.params()[1..].to_vec(), sig.ret().clone())

crates/ra_hir_ty/src/traits.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use super::{Canonical, GenericPredicate, HirDisplay, ProjectionTy, TraitRef, Ty,
1414
use self::chalk::{from_chalk, Interner, ToChalk};
1515

1616
pub(crate) mod chalk;
17-
mod builtin;
17+
pub(crate) mod builtin;
1818

1919
// This controls the maximum size of types Chalk considers. If we set this too
2020
// high, we can run into slow edge cases; if we set it too low, Chalk won't

crates/ra_hir_ty/src/traits/builtin.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,11 @@ fn super_trait_object_unsize_impl_datum(
360360
BuiltinImplData { num_vars, trait_ref, where_clauses: Vec::new(), assoc_ty_values: Vec::new() }
361361
}
362362

363-
fn get_fn_trait(db: &dyn HirDatabase, krate: CrateId, fn_trait: super::FnTrait) -> Option<TraitId> {
363+
pub fn get_fn_trait(
364+
db: &dyn HirDatabase,
365+
krate: CrateId,
366+
fn_trait: super::FnTrait,
367+
) -> Option<TraitId> {
364368
let target = db.lang_item(krate, fn_trait.lang_item_name().into())?;
365369
match target {
366370
LangItemTarget::TraitId(t) => Some(t),

crates/ra_ide/src/hover.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,4 +2410,103 @@ fn func(foo: i32) { if true { <|>foo; }; }
24102410
]
24112411
"###);
24122412
}
2413+
2414+
#[test]
2415+
fn infer_closure_arg() {
2416+
check_hover_result(
2417+
r#"
2418+
//- /lib.rs
2419+
2420+
enum Option<T> {
2421+
None,
2422+
Some(T)
2423+
}
2424+
2425+
fn foo() {
2426+
let s<|> = Option::None;
2427+
let f = |x: Option<i32>| {};
2428+
(&f)(s)
2429+
}
2430+
"#,
2431+
&["Option<i32>"],
2432+
);
2433+
}
2434+
2435+
#[test]
2436+
fn infer_fn_trait_arg() {
2437+
check_hover_result(
2438+
r#"
2439+
//- /lib.rs deps:std
2440+
2441+
#[lang = "fn"]
2442+
pub trait Fn<Args> {
2443+
type Output;
2444+
2445+
extern "rust-call" fn call(&self, args: Args) -> Self::Output;
2446+
}
2447+
2448+
enum Option<T> {
2449+
None,
2450+
Some(T)
2451+
}
2452+
2453+
fn foo<F, T>(f: F) -> T
2454+
where
2455+
F: Fn(Option<i32>) -> T,
2456+
{
2457+
let s<|> = None;
2458+
f(s)
2459+
}
2460+
"#,
2461+
&["Option<i32>"],
2462+
);
2463+
}
2464+
2465+
#[test]
2466+
fn infer_box_fn_arg() {
2467+
check_hover_result(
2468+
r#"
2469+
//- /lib.rs deps:std
2470+
2471+
#[lang = "fn_once"]
2472+
pub trait FnOnce<Args> {
2473+
type Output;
2474+
2475+
extern "rust-call" fn call_once(self, args: Args) -> Self::Output;
2476+
}
2477+
2478+
#[lang = "deref"]
2479+
pub trait Deref {
2480+
type Target: ?Sized;
2481+
2482+
fn deref(&self) -> &Self::Target;
2483+
}
2484+
2485+
#[lang = "owned_box"]
2486+
pub struct Box<T: ?Sized> {
2487+
inner: *mut T,
2488+
}
2489+
2490+
impl<T: ?Sized> Deref for Box<T> {
2491+
type Target = T;
2492+
2493+
fn deref(&self) -> &T {
2494+
&self.inner
2495+
}
2496+
}
2497+
2498+
enum Option<T> {
2499+
None,
2500+
Some(T)
2501+
}
2502+
2503+
fn foo() {
2504+
let s<|> = Option::None;
2505+
let f: Box<dyn FnOnce(&Option<i32>)> = box (|ps| {});
2506+
f(&s)
2507+
}
2508+
"#,
2509+
&["Option<i32>"],
2510+
);
2511+
}
24132512
}

0 commit comments

Comments
 (0)