98
98
get_log_group_deletion_policy ,
99
99
get_shared_storage_ids_by_type ,
100
100
get_slurm_specific_dna_json_for_head_node ,
101
+ get_source_ingress_rule ,
101
102
get_user_data_content ,
102
103
to_comma_separated_string ,
103
104
)
@@ -193,9 +194,10 @@ def _get_login_security_groups(self):
193
194
if isinstance (self .config , SlurmClusterConfig ) and self .config .login_nodes
194
195
else []
195
196
)
196
-
197
+ # Add the managed login node security groups
197
198
if self ._login_security_group :
198
- login_security_groups .append (self ._login_security_group .ref )
199
+ for _ , managed_security_group in self ._login_security_group .items ():
200
+ login_security_groups .append (managed_security_group .ref )
199
201
200
202
return login_security_groups
201
203
@@ -467,7 +469,7 @@ def _add_login_nodes_resources(self):
467
469
shared_storage_infos = self .shared_storage_infos ,
468
470
shared_storage_mount_dirs = self .shared_storage_mount_dirs ,
469
471
shared_storage_attributes = self .shared_storage_attributes ,
470
- login_security_group = self ._login_security_group ,
472
+ login_security_groups = self ._login_security_group ,
471
473
head_eni = self ._head_eni ,
472
474
cluster_hosted_zone = self .scheduler_resources .cluster_hosted_zone if self .scheduler_resources else None ,
473
475
cluster_bucket = self .bucket ,
@@ -574,11 +576,8 @@ def _add_head_eni(self):
574
576
575
577
def _add_security_groups (self ):
576
578
head_node_security_groups , managed_head_security_group = self ._head_security_groups ()
577
- (
578
- login_security_groups ,
579
- managed_login_security_group ,
580
- custom_login_security_groups ,
581
- ) = self ._login_security_groups ()
579
+ login_security_groups , managed_login_security_groups = self ._login_security_groups ()
580
+
582
581
(
583
582
compute_security_groups ,
584
583
managed_compute_security_group ,
@@ -590,13 +589,12 @@ def _add_security_groups(self):
590
589
custom_compute_security_groups ,
591
590
head_node_security_groups ,
592
591
login_security_groups ,
593
- custom_login_security_groups ,
594
592
managed_compute_security_group ,
595
593
managed_head_security_group ,
596
- managed_login_security_group ,
594
+ managed_login_security_groups ,
597
595
)
598
596
599
- return managed_head_security_group , managed_compute_security_group , managed_login_security_group
597
+ return managed_head_security_group , managed_compute_security_group , managed_login_security_groups
600
598
601
599
def _head_security_groups (self ):
602
600
managed_head_security_group = None
@@ -609,7 +607,7 @@ def _head_security_groups(self):
609
607
return head_node_security_groups , managed_head_security_group
610
608
611
609
def _login_security_groups (self ):
612
- managed_login_security_group = None
610
+ managed_login_security_groups = dict ()
613
611
custom_login_security_groups = set ()
614
612
managed_login_security_group_required = False
615
613
if self ._condition_is_slurm () and self .config .login_nodes :
@@ -622,9 +620,11 @@ def _login_security_groups(self):
622
620
managed_login_security_group_required = True
623
621
login_security_groups = list (custom_login_security_groups )
624
622
if managed_login_security_group_required :
625
- managed_login_security_group = self ._add_login_nodes_security_group ()
626
- login_security_groups .append (managed_login_security_group .ref )
627
- return login_security_groups , managed_login_security_group , custom_login_security_groups
623
+ managed_login_security_groups = self ._add_login_nodes_security_group ()
624
+ login_security_groups .extend (
625
+ [security_group .ref for security_group in managed_login_security_groups .values ()]
626
+ )
627
+ return login_security_groups , managed_login_security_groups
628
628
629
629
def _compute_security_groups (self ):
630
630
managed_compute_security_group = None
@@ -649,20 +649,18 @@ def _add_inbounds_to_managed_security_groups(
649
649
custom_compute_security_groups ,
650
650
head_node_security_groups ,
651
651
login_security_groups ,
652
- custom_login_security_groups ,
653
652
managed_compute_security_group ,
654
653
managed_head_security_group ,
655
- managed_login_security_group ,
654
+ managed_login_security_groups ,
656
655
):
657
656
self ._add_inbounds_to_managed_head_security_group (
658
657
compute_security_groups , login_security_groups , managed_head_security_group
659
658
)
660
659
661
- self ._add_inbounds_to_managed_login_security_group (
660
+ self ._add_inbounds_to_managed_login_security_groups (
662
661
head_node_security_groups ,
663
662
compute_security_groups ,
664
- custom_login_security_groups ,
665
- managed_login_security_group ,
663
+ managed_login_security_groups ,
666
664
)
667
665
668
666
self ._add_inbounds_to_managed_compute_security_group (
@@ -698,29 +696,28 @@ def _add_inbounds_to_managed_head_security_group(
698
696
port = NFS_PORT ,
699
697
)
700
698
701
- def _add_inbounds_to_managed_login_security_group (
699
+ def _add_inbounds_to_managed_login_security_groups (
702
700
self ,
703
701
head_node_security_groups ,
704
702
compute_security_groups ,
705
- custom_login_security_groups ,
706
- managed_login_security_group ,
703
+ managed_login_security_groups ,
707
704
):
708
- if managed_login_security_group :
709
- # Access to login nodes from head node and compute nodes
710
- for index , security_group in enumerate ( head_node_security_groups ):
711
- self . _allow_all_ingress (
712
- f"LoginSecurityGroupHeadNodeIngress { index } " , security_group , managed_login_security_group . ref
713
- )
714
- for index , security_group in enumerate ( compute_security_groups ):
715
- self . _allow_all_ingress (
716
- f"LoginSecurityGroupComputeIngress { index } " , security_group , managed_login_security_group . ref
717
- )
718
- for index , security_group in enumerate (custom_login_security_groups ):
719
- self ._allow_all_ingress (
720
- f"LoginSecurityGroupCustomLoginSecurityGroupIngress { index } " ,
721
- security_group ,
722
- managed_login_security_group .ref ,
723
- )
705
+ if managed_login_security_groups :
706
+ for pool_name , managed_security_group in managed_login_security_groups . items ():
707
+ # Access to login nodes from head node
708
+ for index , security_group in enumerate ( head_node_security_groups ):
709
+ self . _allow_all_ingress (
710
+ f" { pool_name } LoginSecurityGroupHeadNodeIngress { index } " ,
711
+ security_group ,
712
+ managed_security_group . ref ,
713
+ )
714
+ # Access to login nodes from compute nodes
715
+ for index , security_group in enumerate (compute_security_groups ):
716
+ self ._allow_all_ingress (
717
+ f" { pool_name } LoginSecurityGroupComputeIngress { index } " ,
718
+ security_group ,
719
+ managed_security_group .ref ,
720
+ )
724
721
725
722
def _add_inbounds_to_managed_compute_security_group (
726
723
self ,
@@ -878,44 +875,41 @@ def _add_compute_security_group(self):
878
875
879
876
return compute_security_group
880
877
881
- def _get_source_ingress_rule (self , setting ):
882
- if setting .startswith ("pl" ):
883
- return ec2 .CfnSecurityGroup .IngressProperty (
884
- ip_protocol = "tcp" , from_port = 22 , to_port = 22 , source_prefix_list_id = setting
885
- )
886
- else :
887
- return ec2 .CfnSecurityGroup .IngressProperty (ip_protocol = "tcp" , from_port = 22 , to_port = 22 , cidr_ip = setting )
888
-
889
878
def _add_login_nodes_security_group (self ):
890
- # TODO review this once we allow more pools to be defined in the LoginNodes section
891
- login_nodes_security_group_ingress = [
892
- # SSH access
893
- self ._get_source_ingress_rule (self .config .login_nodes .pools [0 ].ssh .allowed_ips )
894
- ]
895
-
896
- if self .config .login_nodes .has_dcv_enabled :
897
- login_nodes_security_group_ingress .append (
898
- # DCV access
899
- ec2 .CfnSecurityGroup .IngressProperty (
900
- ip_protocol = "tcp" ,
901
- from_port = self .config .login_nodes .pools [0 ].dcv .port ,
902
- to_port = self .config .login_nodes .pools [0 ].dcv .port ,
903
- cidr_ip = self .config .login_nodes .pools [0 ].dcv .allowed_ips ,
879
+ """Return a dictionary mapping each login node pool name to its respective managed security group."""
880
+ pool_to_managed_security_group_dict = dict ()
881
+
882
+ for pool in self .config .login_nodes .pools :
883
+ # Check if the pool has user-defined security groups
884
+ if not pool .networking .security_groups :
885
+ security_group_ingress_rules = [
886
+ # Add rule for SSH access
887
+ get_source_ingress_rule (pool .ssh .allowed_ips )
888
+ ]
889
+ # Add rule for DCV access if enabled
890
+ if pool .has_dcv_enabled :
891
+ security_group_ingress_rules .append (
892
+ ec2 .CfnSecurityGroup .IngressProperty (
893
+ ip_protocol = "tcp" ,
894
+ from_port = pool .dcv .port ,
895
+ to_port = pool .dcv .port ,
896
+ cidr_ip = pool .dcv .allowed_ips ,
897
+ )
898
+ )
899
+ pool_to_managed_security_group_dict [pool .name ] = ec2 .CfnSecurityGroup (
900
+ self .stack ,
901
+ f"{ pool .name } LoginNodesSecurityGroup" ,
902
+ group_description = "Enable access to the login nodes" ,
903
+ vpc_id = self .config .vpc_id ,
904
+ security_group_ingress = security_group_ingress_rules ,
904
905
)
905
- )
906
906
907
- return ec2 .CfnSecurityGroup (
908
- self .stack ,
909
- "LoginNodesSecurityGroup" ,
910
- group_description = "Enable access to the login nodes" ,
911
- vpc_id = self .config .vpc_id ,
912
- security_group_ingress = login_nodes_security_group_ingress ,
913
- )
907
+ return pool_to_managed_security_group_dict
914
908
915
909
def _add_head_security_group (self ):
916
910
head_security_group_ingress = [
917
911
# SSH access
918
- self . _get_source_ingress_rule (self .config .head_node .ssh .allowed_ips )
912
+ get_source_ingress_rule (self .config .head_node .ssh .allowed_ips )
919
913
]
920
914
921
915
if self .config .is_dcv_enabled :
0 commit comments