Skip to content

Add ide-assist: extract_to_default_generic #20295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions crates/ide-assists/src/handlers/extract_to_default_generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
use ast::Name;
use either::Either::{self, Left, Right};
use ide_db::{source_change::SourceChangeBuilder, syntax_helpers::suggest_name::NameGenerator};
use syntax::{
ast::{self, AstNode, HasGenericParams, HasName, make},
syntax_editor::{Position, SyntaxEditor},
};

use crate::{AssistContext, AssistId, Assists};

// Assist: extract_to_default_generic
//
// Extracts selected type to default generic parameter.
//
// ```
// struct Foo(u32, $0String$0);
// ```
// ->
// ```
// struct Foo<T = String>(u32, T);
// ```
pub(crate) fn extract_to_default_generic(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
if ctx.has_empty_selection() {
return None;
}

let ty: Either<ast::Type, ast::ConstArg> = ctx.find_node_at_range()?;
let adt: Either<ast::Adt, Either<ast::TypeAlias, ast::Fn>> =
ty.syntax().ancestors().find_map(AstNode::cast)?;

extract_to_default_generic_impl(acc, ctx, adt, ty)
}

fn extract_to_default_generic_impl(
acc: &mut Assists,
ctx: &AssistContext<'_>,
adt: impl HasName + HasGenericParams,
ty: Either<ast::Type, ast::ConstArg>,
) -> Option<()> {
let name = adt.name()?;

let target = ty.syntax().text_range();
acc.add(
AssistId::refactor_extract("extract_to_default_generic"),
"Extract type as default generic parameter",
target,
|edit| {
let mut editor = edit.make_editor(adt.syntax());
let generic_list = get_or_create_generic_param_list(&name, &adt, &mut editor, edit);

let generic_name = generic_name(&generic_list, ty.is_right());

editor.replace(ty.syntax(), generic_name.syntax());

match ty {
Left(ty) => {
let param = make::type_default_param(generic_name, None, ty).clone_for_update();
generic_list.add_generic_param(param.into());
}
Right(n) => {
let param = make::const_default_param(generic_name, const_ty(ctx, &n), n)
.clone_for_update();
generic_list.add_generic_param(param.into());

if let Some(ast::GenericParam::ConstParam(param)) =
generic_list.generic_params().last()
&& let Some(ast::Type::InferType(ty)) = param.ty()
&& let Some(cap) = ctx.config.snippet_cap
{
let annotation = edit.make_placeholder_snippet(cap);
editor.add_annotation(ty.syntax(), annotation);
}
}
}

edit.add_file_edits(ctx.vfs_file_id(), editor);
},
)
}

fn array_index_type(n: &ast::ConstArg) -> Option<ast::Type> {
let kind = n.syntax().parent()?.kind();

if ast::ArrayType::can_cast(kind) || ast::ArrayExpr::can_cast(kind) {
Some(make::ty("usize"))
} else {
None
}
}

fn generic_name(generic_list: &ast::GenericParamList, is_const_param: bool) -> Name {
let exist_names = generic_list
.generic_params()
.filter_map(|it| match it {
ast::GenericParam::ConstParam(const_param) => const_param.name(),
ast::GenericParam::TypeParam(type_param) => type_param.name(),
ast::GenericParam::LifetimeParam(_) => None,
})
.map(|name| name.to_string())
.collect::<Vec<_>>();

let mut name_gen = NameGenerator::new_with_names(exist_names.iter().map(|name| name.as_str()));

make::name(&if is_const_param {
name_gen.suggest_name("N")
} else {
name_gen.suggest_name("T")
})
.clone_for_update()
}

fn const_ty(ctx: &AssistContext<'_>, n: &ast::ConstArg) -> ast::Type {
if let Some(expr) = n.expr()
&& let Some(ty_info) = ctx.sema.type_of_expr(&expr)
&& let Some(builtin) = ty_info.adjusted().as_builtin()
{
make::ty(builtin.name().as_str())
} else if let Some(array_index_ty) = array_index_type(n) {
array_index_ty
} else {
make::ty_placeholder()
}
}

fn get_or_create_generic_param_list(
name: &ast::Name,
adt: &impl HasGenericParams,
editor: &mut SyntaxEditor,
edit: &mut SourceChangeBuilder,
) -> ast::GenericParamList {
if let Some(list) = adt.generic_param_list() {
edit.make_mut(list)
} else {
let generic = make::generic_param_list([]).clone_for_update();
editor.insert(Position::after(name.syntax()), generic.syntax());
generic
}
}

#[cfg(test)]
mod tests {
use super::*;

use crate::tests::check_assist;

#[test]
fn test_extract_to_default_generic() {
check_assist(
extract_to_default_generic,
r#"type X = ($0i32$0, i64);"#,
r#"type X<T = i32> = (T, i64);"#,
);

check_assist(
extract_to_default_generic,
r#"type X<T> = ($0i32$0, T);"#,
r#"type X<T, T1 = i32> = (T1, T);"#,
);
}

#[test]
fn test_extract_to_default_generic_on_adt() {
check_assist(
extract_to_default_generic,
r#"struct Foo($0i32$0);"#,
r#"struct Foo<T = i32>(T);"#,
);

check_assist(
extract_to_default_generic,
r#"struct Foo<T>(T, $0i32$0);"#,
r#"struct Foo<T, T1 = i32>(T, T1);"#,
);

check_assist(
extract_to_default_generic,
r#"enum Foo { A($0i32$0), B, C(i64) };"#,
r#"enum Foo<T = i32> { A(T), B, C(i64) };"#,
);
}

#[test]
fn test_extract_to_default_generic_on_fn() {
check_assist(
extract_to_default_generic,
r#"fn foo(x: $0i32$0) {}"#,
r#"fn foo<T = i32>(x: T) {}"#,
);
}

#[test]
fn test_extract_to_default_generic_const() {
check_assist(
extract_to_default_generic,
r#"type A = [i32; $08$0];"#,
r#"type A<const N: usize = 8> = [i32; N];"#,
);

check_assist(
extract_to_default_generic,
r#"type A<T> = [T; $08$0];"#,
r#"type A<T, const N: usize = 8> = [T; N];"#,
);
}

#[test]
fn test_extract_to_default_generic_const_non_array() {
check_assist(
extract_to_default_generic,
r#"
struct Foo<const N: usize>([(); N]);
type A = Foo<$08$0>;
"#,
r#"
struct Foo<const N: usize>([(); N]);
type A<const N: ${0:_} = 8> = Foo<N>;
"#,
);
}
}
2 changes: 2 additions & 0 deletions crates/ide-assists/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ mod handlers {
mod extract_function;
mod extract_module;
mod extract_struct_from_enum_variant;
mod extract_to_default_generic;
mod extract_type_alias;
mod extract_variable;
mod fix_visibility;
Expand Down Expand Up @@ -281,6 +282,7 @@ mod handlers {
extract_expressions_from_format_string::extract_expressions_from_format_string,
extract_struct_from_enum_variant::extract_struct_from_enum_variant,
extract_type_alias::extract_type_alias,
extract_to_default_generic::extract_to_default_generic,
fix_visibility::fix_visibility,
flip_binexpr::flip_binexpr,
flip_comma::flip_comma,
Expand Down
13 changes: 13 additions & 0 deletions crates/ide-assists/src/tests/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,19 @@ enum A { One(One) }
)
}

#[test]
fn doctest_extract_to_default_generic() {
check_doc_test(
"extract_to_default_generic",
r#####"
struct Foo(u32, $0String$0);
"#####,
r#####"
struct Foo<T = String>(u32, T);
"#####,
)
}

#[test]
fn doctest_extract_type_alias() {
check_doc_test(
Expand Down
17 changes: 17 additions & 0 deletions crates/syntax/src/ast/make.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,10 +1069,27 @@ pub fn type_param(name: ast::Name, bounds: Option<ast::TypeBoundList>) -> ast::T
ast_from_text(&format!("fn f<{name}{bounds}>() {{ }}"))
}

pub fn type_default_param(
name: ast::Name,
bounds: Option<ast::TypeBoundList>,
default: ast::Type,
) -> ast::TypeParam {
let bounds = bounds.map_or_else(String::new, |it| format!(": {it}"));
ast_from_text(&format!("fn f<{name}{bounds} = {default}>() {{ }}"))
}

pub fn const_param(name: ast::Name, ty: ast::Type) -> ast::ConstParam {
ast_from_text(&format!("fn f<const {name}: {ty}>() {{ }}"))
}

pub fn const_default_param(
name: ast::Name,
ty: ast::Type,
default: ast::ConstArg,
) -> ast::ConstParam {
ast_from_text(&format!("fn f<const {name}: {ty} = {default}>() {{ }}"))
}

pub fn lifetime_param(lifetime: ast::Lifetime) -> ast::LifetimeParam {
ast_from_text(&format!("fn f<{lifetime}>() {{ }}"))
}
Expand Down
1 change: 1 addition & 0 deletions crates/syntax/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ impl Iterator for AttrDocCommentIter {
}

impl<A: HasName, B: HasName> HasName for Either<A, B> {}
impl<A: HasGenericParams, B: HasGenericParams> HasGenericParams for Either<A, B> {}
Loading