Skip to content

Commit 41563fd

Browse files
committed
Infer variants through type aliased enums
1 parent e6a1c9c commit 41563fd

File tree

2 files changed

+70
-33
lines changed

2 files changed

+70
-33
lines changed

crates/hir_ty/src/infer.rs

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -484,36 +484,13 @@ impl<'a> InferenceContext<'a> {
484484
let generics = crate::utils::generics(self.db.upcast(), impl_id.into());
485485
let substs = generics.type_params_subst(self.db);
486486
let ty = self.db.impl_self_ty(impl_id).substitute(&Interner, &substs);
487-
match unresolved {
488-
None => {
489-
let variant = ty_variant(&ty);
490-
(ty, variant)
491-
}
492-
Some(1) => {
493-
let segment = path.mod_path().segments().last().unwrap();
494-
// this could be an enum variant or associated type
495-
if let Some((AdtId::EnumId(enum_id), _)) = ty.as_adt() {
496-
let enum_data = self.db.enum_data(enum_id);
497-
if let Some(local_id) = enum_data.variant(segment) {
498-
let variant = EnumVariantId { parent: enum_id, local_id };
499-
return (ty, Some(variant.into()));
500-
}
501-
}
502-
// FIXME potentially resolve assoc type
503-
(self.err_ty(), None)
504-
}
505-
Some(_) => {
506-
// FIXME diagnostic
507-
(self.err_ty(), None)
508-
}
509-
}
487+
self.resolve_variant_on_alias(ty, unresolved, path)
510488
}
511489
TypeNs::TypeAliasId(it) => {
512490
let ty = TyBuilder::def_ty(self.db, it.into())
513491
.fill(std::iter::repeat_with(|| self.table.new_type_var()))
514492
.build();
515-
let variant = ty_variant(&ty);
516-
forbid_unresolved_segments((ty, variant), unresolved)
493+
self.resolve_variant_on_alias(ty, unresolved, path)
517494
}
518495
TypeNs::AdtSelfType(_) => {
519496
// FIXME this could happen in array size expressions, once we're checking them
@@ -540,16 +517,43 @@ impl<'a> InferenceContext<'a> {
540517
(TyKind::Error.intern(&Interner), None)
541518
}
542519
}
520+
}
543521

544-
fn ty_variant(ty: &Ty) -> Option<VariantId> {
545-
ty.as_adt().and_then(|(adt_id, _)| match adt_id {
546-
AdtId::StructId(s) => Some(VariantId::StructId(s)),
547-
AdtId::UnionId(u) => Some(VariantId::UnionId(u)),
548-
AdtId::EnumId(_) => {
549-
// FIXME Error E0071, expected struct, variant or union type, found enum `Foo`
550-
None
522+
fn resolve_variant_on_alias(
523+
&mut self,
524+
ty: Ty,
525+
unresolved: Option<usize>,
526+
path: &Path,
527+
) -> (Ty, Option<VariantId>) {
528+
match unresolved {
529+
None => {
530+
let variant = ty.as_adt().and_then(|(adt_id, _)| match adt_id {
531+
AdtId::StructId(s) => Some(VariantId::StructId(s)),
532+
AdtId::UnionId(u) => Some(VariantId::UnionId(u)),
533+
AdtId::EnumId(_) => {
534+
// FIXME Error E0071, expected struct, variant or union type, found enum `Foo`
535+
None
536+
}
537+
});
538+
(ty, variant)
539+
}
540+
Some(1) => {
541+
let segment = path.mod_path().segments().last().unwrap();
542+
// this could be an enum variant or associated type
543+
if let Some((AdtId::EnumId(enum_id), _)) = ty.as_adt() {
544+
let enum_data = self.db.enum_data(enum_id);
545+
if let Some(local_id) = enum_data.variant(segment) {
546+
let variant = EnumVariantId { parent: enum_id, local_id };
547+
return (ty, Some(variant.into()));
548+
}
551549
}
552-
})
550+
// FIXME potentially resolve assoc type
551+
(self.err_ty(), None)
552+
}
553+
Some(_) => {
554+
// FIXME diagnostic
555+
(self.err_ty(), None)
556+
}
553557
}
554558
}
555559

crates/hir_ty/src/tests/simple.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,3 +2564,36 @@ fn f() {
25642564
"#,
25652565
)
25662566
}
2567+
2568+
#[test]
2569+
fn infer_type_alias_variant() {
2570+
check_infer(
2571+
r#"
2572+
type Qux = Foo;
2573+
enum Foo {
2574+
Bar(i32),
2575+
Baz { baz: f32 }
2576+
}
2577+
2578+
fn f() {
2579+
match Foo::Bar(3) {
2580+
Qux::Bar(bar) => (),
2581+
Qux::Baz { baz } => (),
2582+
}
2583+
}
2584+
"#,
2585+
expect![[r#"
2586+
72..166 '{ ... } }': ()
2587+
78..164 'match ... }': ()
2588+
84..92 'Foo::Bar': Bar(i32) -> Foo
2589+
84..95 'Foo::Bar(3)': Foo
2590+
93..94 '3': i32
2591+
106..119 'Qux::Bar(bar)': Foo
2592+
115..118 'bar': i32
2593+
123..125 '()': ()
2594+
135..151 'Qux::B... baz }': Foo
2595+
146..149 'baz': f32
2596+
155..157 '()': ()
2597+
"#]],
2598+
)
2599+
}

0 commit comments

Comments
 (0)