Skip to content

Commit d645b81

Browse files
committed
generate_function assist infer return type
1 parent 71b8fb7 commit d645b81

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

crates/ide_assists/src/handlers/generate_function.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct FunctionBuilder {
104104
fn_name: ast::Name,
105105
type_params: Option<ast::GenericParamList>,
106106
params: ast::ParamList,
107+
ret_type: Option<ast::RetType>,
107108
file: FileId,
108109
needs_pub: bool,
109110
}
@@ -131,8 +132,9 @@ impl FunctionBuilder {
131132
let target_module = target_module.or_else(|| ctx.sema.scope(target.syntax()).module())?;
132133
let fn_name = fn_name(&path)?;
133134
let (type_params, params) = fn_args(ctx, target_module, &call)?;
135+
let ret_type = fn_ret_type(ctx, target_module, &call);
134136

135-
Some(Self { target, fn_name, type_params, params, file, needs_pub })
137+
Some(Self { target, fn_name, type_params, params, ret_type, file, needs_pub })
136138
}
137139

138140
fn render(self) -> FunctionTemplate {
@@ -145,7 +147,7 @@ impl FunctionBuilder {
145147
self.type_params,
146148
self.params,
147149
fn_body,
148-
Some(make::ret_type(make::ty_unit())),
150+
Some(self.ret_type.unwrap_or_else(|| make::ret_type(make::ty_unit()))),
149151
);
150152
let leading_ws;
151153
let trailing_ws;
@@ -223,6 +225,23 @@ fn fn_args(
223225
Some((None, make::param_list(None, params)))
224226
}
225227

228+
fn fn_ret_type(
229+
ctx: &AssistContext,
230+
target_module: hir::Module,
231+
call: &ast::CallExpr,
232+
) -> Option<ast::RetType> {
233+
let ty = ctx.sema.type_of_expr(&ast::Expr::CallExpr(call.clone()))?;
234+
if ty.is_unknown() {
235+
return None;
236+
}
237+
238+
if let Ok(rendered) = ty.display_source_code(ctx.db(), target_module.into()) {
239+
Some(make::ret_type(make::ty(&rendered)))
240+
} else {
241+
None
242+
}
243+
}
244+
226245
/// Makes duplicate argument names unique by appending incrementing numbers.
227246
///
228247
/// ```
@@ -546,7 +565,7 @@ impl Baz {
546565
}
547566
}
548567
549-
fn bar(baz: Baz) ${0:-> ()} {
568+
fn bar(baz: Baz) ${0:-> Baz} {
550569
todo!()
551570
}
552571
",
@@ -1059,6 +1078,27 @@ pub(crate) fn bar() ${0:-> ()} {
10591078
)
10601079
}
10611080

1081+
#[test]
1082+
fn add_function_with_return_type() {
1083+
check_assist(
1084+
generate_function,
1085+
r"
1086+
fn main() {
1087+
let x: u32 = foo$0();
1088+
}
1089+
",
1090+
r"
1091+
fn main() {
1092+
let x: u32 = foo();
1093+
}
1094+
1095+
fn foo() ${0:-> u32} {
1096+
todo!()
1097+
}
1098+
",
1099+
)
1100+
}
1101+
10621102
#[test]
10631103
fn add_function_not_applicable_if_function_already_exists() {
10641104
check_assist_not_applicable(

0 commit comments

Comments
 (0)