@@ -636,6 +636,154 @@ impl<'a> TypeCheck<'a> {
636
636
637
637
fn check_expr_if ( & mut self , node : & ast:: ExprIfType , expected_ty : SourceType ) -> SourceType {
638
638
self . check_stmt_let ( & node. cond ) ;
639
+ let cond = node. cond . expr . as_ref ( ) . unwrap ( ) ;
640
+ let expr_type = self . analysis . ty ( cond. id ( ) ) ;
641
+ let expr_enum_id = expr_type. enum_id ( ) ;
642
+
643
+ if expr_enum_id. is_some ( ) {
644
+ self . check_expr_if_enum ( node, expected_ty)
645
+ } else {
646
+ self . check_expr_if_union ( node, expected_ty)
647
+ }
648
+ }
649
+
650
+ fn check_expr_if_union ( & mut self , node : & ast:: ExprIfType , expected_ty : SourceType ) -> SourceType {
651
+ let cond = node. cond . expr . as_ref ( ) . unwrap ( ) ;
652
+ let expr_type = self . analysis . ty ( cond. id ( ) ) ;
653
+ let mut result_type = SourceType :: Error ;
654
+
655
+ let expr_union_id = expr_type. union_id ( ) ;
656
+ let expr_type_params = expr_type. type_params ( ) ;
657
+
658
+ let union_variants = if let Some ( expr_union_id) = expr_union_id {
659
+ let union_ = self . sa . unions [ expr_union_id] . read ( ) ;
660
+ union_. variants . len ( )
661
+ } else {
662
+ 0
663
+ } ;
664
+
665
+ let mut used_variants = FixedBitSet :: with_capacity ( union_variants) ;
666
+ let mut non_variant_cases = false ;
667
+
668
+ for case in & node. cases {
669
+ self . symtable . push_level ( ) ;
670
+
671
+ match & case. data {
672
+ ast:: IfCaseData :: Simple => {
673
+ self . check_if_condition_is_bool ( expr_type. clone ( ) , cond) ;
674
+ non_variant_cases = true ;
675
+ }
676
+
677
+ ast:: IfCaseData :: Continuation ( continuation) => {
678
+ let cont_type = self . check_expr ( & continuation, expected_ty. clone ( ) ) ;
679
+ self . check_if_condition_is_bool ( cont_type, continuation) ;
680
+ non_variant_cases = true ;
681
+ }
682
+
683
+ ast:: IfCaseData :: Patterns ( patterns) => {
684
+ debug_assert_eq ! ( patterns. len( ) , 1 ) ;
685
+ if !expr_type. is_union ( ) {
686
+ self . sa . diag . lock ( ) . report (
687
+ self . file_id ,
688
+ node. pos ,
689
+ ErrorMessage :: UnionExpected ,
690
+ ) ;
691
+ }
692
+ let pattern = patterns. first ( ) . expect ( "no pattern" ) ;
693
+ let sym = self . read_path ( & pattern. path ) ;
694
+
695
+ let mut used_idents: HashSet < Name > = HashSet :: new ( ) ;
696
+
697
+ match sym {
698
+ Ok ( Sym :: Class ( class_id) ) => {
699
+ let union = self . sa . unions . idx ( expr_type. union_id ( ) . unwrap ( ) ) ;
700
+ let union = union. read ( ) ;
701
+ let found = union. variants . iter ( ) . find ( |v| v. type_ . cls_id ( ) . unwrap ( ) == class_id) ;
702
+ if !found. is_some ( ) {
703
+ panic ! ( )
704
+ }
705
+ used_variants. insert ( found. unwrap ( ) . id ) ;
706
+ if cond. is_ident ( ) {
707
+ let tpe = found. unwrap ( ) . clone ( ) . type_ ;
708
+ let var_id = self . vars . add_var ( cond. to_ident ( ) . unwrap ( ) . name , tpe, false ) ;
709
+ self . add_local ( var_id, cond. pos ( ) ) ;
710
+ self . analysis
711
+ . map_vars
712
+ . insert ( cond. id ( ) , self . vars . local_var_id ( var_id) ) ;
713
+ }
714
+ }
715
+
716
+ Ok ( Sym :: Value ( value_id) ) => {
717
+ unimplemented ! ( )
718
+ }
719
+
720
+ Ok ( _) => {
721
+ let msg = ErrorMessage :: UnionVariantExpected ;
722
+ self . sa . diag . lock ( ) . report ( self . file_id , node. pos , msg) ;
723
+ }
724
+
725
+ Err ( ( ) ) => { }
726
+ }
727
+ }
728
+ }
729
+
730
+ let case_ty = self . check_expr ( & case. value , expected_ty. clone ( ) ) ;
731
+
732
+ if result_type. is_error ( ) {
733
+ result_type = case_ty;
734
+ } else if case_ty. is_error ( ) {
735
+ // ignore this case
736
+ } else if !result_type. allows ( self . sa , case_ty. clone ( ) ) {
737
+ let result_type_name = result_type. name_fct ( self . sa , self . fct ) ;
738
+ let case_ty_name = case_ty. name_fct ( self . sa , self . fct ) ;
739
+ let msg = ErrorMessage :: IfBranchTypesIncompatible ( result_type_name, case_ty_name) ;
740
+ self . sa
741
+ . diag
742
+ . lock ( )
743
+ . report ( self . file_id , case. value . pos ( ) , msg) ;
744
+ }
745
+
746
+ self . symtable . pop_level ( ) ;
747
+ }
748
+
749
+ used_variants. toggle_range ( ..) ;
750
+ let is_exhaustive = used_variants. count_ones ( ..) == 0 ;
751
+ if !is_exhaustive && node. else_block . is_none ( ) {
752
+ let msg = ErrorMessage :: IfPatternVariantUncovered ;
753
+ self . sa . diag . lock ( ) . report ( self . file_id , node. pos , msg) ;
754
+ } else if let Some ( else_block) = & node. else_block {
755
+ if is_exhaustive && !non_variant_cases {
756
+ let msg = ErrorMessage :: IfPatternUnreachable ;
757
+ self . sa
758
+ . diag
759
+ . lock ( )
760
+ . report ( self . file_id , else_block. pos ( ) , msg)
761
+ } else {
762
+ let else_ty = self . check_expr ( & else_block, expected_ty. clone ( ) ) ;
763
+
764
+ if result_type. is_error ( ) {
765
+ result_type = else_ty;
766
+ } else if else_ty. is_error ( ) {
767
+ // ignore this case
768
+ } else if !result_type. allows ( self . sa , else_ty. clone ( ) ) {
769
+ let result_type_name = result_type. name_fct ( self . sa , self . fct ) ;
770
+ let else_ty_name = else_ty. name_fct ( self . sa , self . fct ) ;
771
+ let msg =
772
+ ErrorMessage :: IfBranchTypesIncompatible ( result_type_name, else_ty_name) ;
773
+ self . sa
774
+ . diag
775
+ . lock ( )
776
+ . report ( self . file_id , else_block. pos ( ) , msg) ;
777
+ }
778
+ }
779
+ }
780
+
781
+ self . analysis . set_ty ( node. id , result_type. clone ( ) ) ;
782
+
783
+ result_type
784
+ }
785
+
786
+ fn check_expr_if_enum ( & mut self , node : & ast:: ExprIfType , expected_ty : SourceType ) -> SourceType {
639
787
let cond = node. cond . expr . as_ref ( ) . unwrap ( ) ;
640
788
let expr_type = self . analysis . ty ( cond. id ( ) ) ;
641
789
let mut result_type = SourceType :: Error ;
@@ -669,7 +817,7 @@ impl<'a> TypeCheck<'a> {
669
817
}
670
818
ast:: IfCaseData :: Patterns ( patterns) => {
671
819
debug_assert_eq ! ( patterns. len( ) , 1 ) ;
672
- if !expr_type. is_enum ( ) && !expr_type . is_union ( ) {
820
+ if !expr_type. is_enum ( ) {
673
821
self . sa . diag . lock ( ) . report (
674
822
self . file_id ,
675
823
node. pos ,
0 commit comments