Skip to content

Commit 734fe66

Browse files
committed
Handle nested types in unwrap_result_return_type assist
1 parent 797c2f1 commit 734fe66

File tree

1 file changed

+99
-23
lines changed

1 file changed

+99
-23
lines changed

crates/ide-assists/src/handlers/unwrap_result_return_type.rs

Lines changed: 99 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ide_db::{
55
use itertools::Itertools;
66
use syntax::{
77
ast::{self, Expr},
8-
match_ast, AstNode, TextRange, TextSize,
8+
match_ast, AstNode, NodeOrToken, SyntaxKind, TextRange, TextSize,
99
};
1010

1111
use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
3838
};
3939

4040
let type_ref = &ret_type.ty()?;
41-
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
41+
let Some(hir::Adt::Enum(ret_enum)) = ctx.sema.resolve_type(type_ref)?.as_adt() else { return None; };
4242
let result_enum =
4343
FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
44-
45-
if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
44+
if ret_enum != result_enum {
4645
return None;
4746
}
4847

48+
let Some(ok_type) = unwrap_result_type(type_ref) else { return None; };
49+
4950
acc.add(
5051
AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
5152
"Unwrap Result return type",
@@ -64,26 +65,22 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
6465
});
6566
for_each_tail_expr(&body, tail_cb);
6667

67-
let mut is_unit_type = false;
68-
if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
69-
let inner_type = match inner_type.split_once(',') {
70-
Some((success_inner_type, _)) => success_inner_type,
71-
None => inner_type,
72-
};
73-
let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
74-
if new_ret_type == "()" {
75-
is_unit_type = true;
76-
let text_range = TextRange::new(
77-
ret_type.syntax().text_range().start(),
78-
ret_type.syntax().text_range().end() + TextSize::from(1u32),
79-
);
80-
builder.delete(text_range)
81-
} else {
82-
builder.replace(
83-
type_ref.syntax().text_range(),
84-
inner_type.strip_suffix('>').unwrap_or(inner_type),
85-
)
68+
let is_unit_type = is_unit_type(&ok_type);
69+
if is_unit_type {
70+
let mut text_range = ret_type.syntax().text_range();
71+
72+
if let Some(NodeOrToken::Token(token)) = ret_type.syntax().next_sibling_or_token() {
73+
if token.kind() == SyntaxKind::WHITESPACE {
74+
text_range = TextRange::new(
75+
text_range.start(),
76+
text_range.end() + TextSize::from(1u32),
77+
);
78+
}
8679
}
80+
81+
builder.delete(text_range);
82+
} else {
83+
builder.replace(type_ref.syntax().text_range(), ok_type.syntax().text());
8784
}
8885

8986
for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +131,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134131
}
135132
}
136133

134+
// Tries to extract `T` from `Result<T, E>`.
135+
fn unwrap_result_type(ty: &ast::Type) -> Option<ast::Type> {
136+
let ast::Type::PathType(path_ty) = ty else { return None; };
137+
let Some(path) = path_ty.path() else { return None; };
138+
let Some(segment) = path.first_segment() else { return None; };
139+
let Some(generic_arg_list) = segment.generic_arg_list() else { return None; };
140+
let generic_args: Vec<_> = generic_arg_list.generic_args().collect();
141+
let Some(ast::GenericArg::TypeArg(ok_type)) = generic_args.first() else { return None; };
142+
ok_type.ty()
143+
}
144+
145+
fn is_unit_type(ty: &ast::Type) -> bool {
146+
let ast::Type::TupleType(tuple) = ty else { return false };
147+
tuple.fields().next().is_none()
148+
}
149+
137150
#[cfg(test)]
138151
mod tests {
139152
use crate::tests::{check_assist, check_assist_not_applicable};
@@ -173,6 +186,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173186
r#"
174187
fn foo() {
175188
}
189+
"#,
190+
);
191+
192+
// Unformatted return type
193+
check_assist(
194+
unwrap_result_return_type,
195+
r#"
196+
//- minicore: result
197+
fn foo() -> Result<(), Box<dyn Error$0>>{
198+
Ok(())
199+
}
200+
"#,
201+
r#"
202+
fn foo() {
203+
}
176204
"#,
177205
);
178206
}
@@ -1014,6 +1042,54 @@ fn foo(the_field: u32) -> u32 {
10141042
}
10151043
the_field
10161044
}
1045+
"#,
1046+
);
1047+
}
1048+
1049+
#[test]
1050+
fn unwrap_result_return_type_nested_type() {
1051+
check_assist(
1052+
unwrap_result_return_type,
1053+
r#"
1054+
//- minicore: result, option
1055+
fn foo() -> Result<Option<i32$0>, ()> {
1056+
Ok(Some(42))
1057+
}
1058+
"#,
1059+
r#"
1060+
fn foo() -> Option<i32> {
1061+
Some(42)
1062+
}
1063+
"#,
1064+
);
1065+
1066+
check_assist(
1067+
unwrap_result_return_type,
1068+
r#"
1069+
//- minicore: result, option
1070+
fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1071+
Ok(None)
1072+
}
1073+
"#,
1074+
r#"
1075+
fn foo() -> Option<Result<i32, ()>> {
1076+
None
1077+
}
1078+
"#,
1079+
);
1080+
1081+
check_assist(
1082+
unwrap_result_return_type,
1083+
r#"
1084+
//- minicore: result, option, iterators
1085+
fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1086+
Ok(Some(42).into_iter())
1087+
}
1088+
"#,
1089+
r#"
1090+
fn foo() -> impl Iterator<Item = i32> {
1091+
Some(42).into_iter()
1092+
}
10171093
"#,
10181094
);
10191095
}

0 commit comments

Comments
 (0)