@@ -5,7 +5,7 @@ use ide_db::{
5
5
use itertools:: Itertools ;
6
6
use syntax:: {
7
7
ast:: { self , Expr } ,
8
- match_ast, AstNode , TextRange , TextSize ,
8
+ match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange , TextSize ,
9
9
} ;
10
10
11
11
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
38
38
} ;
39
39
40
40
let type_ref = & ret_type. ty ( ) ?;
41
- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
41
+ let Some ( hir :: Adt :: Enum ( ret_enum ) ) = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) else { return None ; } ;
42
42
let result_enum =
43
43
FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
44
-
45
- if !matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
44
+ if ret_enum != result_enum {
46
45
return None ;
47
46
}
48
47
48
+ let Some ( ok_type) = unwrap_result_type ( type_ref) else { return None ; } ;
49
+
49
50
acc. add (
50
51
AssistId ( "unwrap_result_return_type" , AssistKind :: RefactorRewrite ) ,
51
52
"Unwrap Result return type" ,
@@ -64,26 +65,22 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
64
65
} ) ;
65
66
for_each_tail_expr ( & body, tail_cb) ;
66
67
67
- let mut is_unit_type = false ;
68
- if let Some ( ( _, inner_type) ) = type_ref. to_string ( ) . split_once ( '<' ) {
69
- let inner_type = match inner_type. split_once ( ',' ) {
70
- Some ( ( success_inner_type, _) ) => success_inner_type,
71
- None => inner_type,
72
- } ;
73
- let new_ret_type = inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ;
74
- if new_ret_type == "()" {
75
- is_unit_type = true ;
76
- let text_range = TextRange :: new (
77
- ret_type. syntax ( ) . text_range ( ) . start ( ) ,
78
- ret_type. syntax ( ) . text_range ( ) . end ( ) + TextSize :: from ( 1u32 ) ,
79
- ) ;
80
- builder. delete ( text_range)
81
- } else {
82
- builder. replace (
83
- type_ref. syntax ( ) . text_range ( ) ,
84
- inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ,
85
- )
68
+ let is_unit_type = is_unit_type ( & ok_type) ;
69
+ if is_unit_type {
70
+ let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
71
+
72
+ if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
73
+ if token. kind ( ) == SyntaxKind :: WHITESPACE {
74
+ text_range = TextRange :: new (
75
+ text_range. start ( ) ,
76
+ text_range. end ( ) + TextSize :: from ( 1u32 ) ,
77
+ ) ;
78
+ }
86
79
}
80
+
81
+ builder. delete ( text_range) ;
82
+ } else {
83
+ builder. replace ( type_ref. syntax ( ) . text_range ( ) , ok_type. syntax ( ) . text ( ) ) ;
87
84
}
88
85
89
86
for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +131,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134
131
}
135
132
}
136
133
134
+ // Tries to extract `T` from `Result<T, E>`.
135
+ fn unwrap_result_type ( ty : & ast:: Type ) -> Option < ast:: Type > {
136
+ let ast:: Type :: PathType ( path_ty) = ty else { return None ; } ;
137
+ let Some ( path) = path_ty. path ( ) else { return None ; } ;
138
+ let Some ( segment) = path. first_segment ( ) else { return None ; } ;
139
+ let Some ( generic_arg_list) = segment. generic_arg_list ( ) else { return None ; } ;
140
+ let generic_args: Vec < _ > = generic_arg_list. generic_args ( ) . collect ( ) ;
141
+ let Some ( ast:: GenericArg :: TypeArg ( ok_type) ) = generic_args. first ( ) else { return None ; } ;
142
+ ok_type. ty ( )
143
+ }
144
+
145
+ fn is_unit_type ( ty : & ast:: Type ) -> bool {
146
+ let ast:: Type :: TupleType ( tuple) = ty else { return false } ;
147
+ tuple. fields ( ) . next ( ) . is_none ( )
148
+ }
149
+
137
150
#[ cfg( test) ]
138
151
mod tests {
139
152
use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -173,6 +186,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173
186
r#"
174
187
fn foo() {
175
188
}
189
+ "# ,
190
+ ) ;
191
+
192
+ // Unformatted return type
193
+ check_assist (
194
+ unwrap_result_return_type,
195
+ r#"
196
+ //- minicore: result
197
+ fn foo() -> Result<(), Box<dyn Error$0>>{
198
+ Ok(())
199
+ }
200
+ "# ,
201
+ r#"
202
+ fn foo() {
203
+ }
176
204
"# ,
177
205
) ;
178
206
}
@@ -1014,6 +1042,54 @@ fn foo(the_field: u32) -> u32 {
1014
1042
}
1015
1043
the_field
1016
1044
}
1045
+ "# ,
1046
+ ) ;
1047
+ }
1048
+
1049
+ #[ test]
1050
+ fn unwrap_result_return_type_nested_type ( ) {
1051
+ check_assist (
1052
+ unwrap_result_return_type,
1053
+ r#"
1054
+ //- minicore: result, option
1055
+ fn foo() -> Result<Option<i32$0>, ()> {
1056
+ Ok(Some(42))
1057
+ }
1058
+ "# ,
1059
+ r#"
1060
+ fn foo() -> Option<i32> {
1061
+ Some(42)
1062
+ }
1063
+ "# ,
1064
+ ) ;
1065
+
1066
+ check_assist (
1067
+ unwrap_result_return_type,
1068
+ r#"
1069
+ //- minicore: result, option
1070
+ fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1071
+ Ok(None)
1072
+ }
1073
+ "# ,
1074
+ r#"
1075
+ fn foo() -> Option<Result<i32, ()>> {
1076
+ None
1077
+ }
1078
+ "# ,
1079
+ ) ;
1080
+
1081
+ check_assist (
1082
+ unwrap_result_return_type,
1083
+ r#"
1084
+ //- minicore: result, option, iterators
1085
+ fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1086
+ Ok(Some(42).into_iter())
1087
+ }
1088
+ "# ,
1089
+ r#"
1090
+ fn foo() -> impl Iterator<Item = i32> {
1091
+ Some(42).into_iter()
1092
+ }
1017
1093
"# ,
1018
1094
) ;
1019
1095
}
0 commit comments