Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class AWSInputVars(schema.Base):
kubernetes_version: str
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
] = "public_and_private"
eks_kms_arn: Optional[str] = None
eks_public_access_cidrs: Optional[List[str]] = ["0.0.0.0/0"]
node_groups: List[AWSNodeGroupInputVars]
Expand Down Expand Up @@ -457,7 +457,7 @@ class AmazonWebServicesProvider(schema.Base):
node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
] = "public_and_private"
eks_public_access_cidrs: Optional[List[str]] = ["0.0.0.0/0"]
eks_kms_arn: Optional[str] = None
existing_subnet_ids: Optional[List[str]] = None
Expand Down
13 changes: 7 additions & 6 deletions src/_nebari/stages/infrastructure/template/aws/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ data "aws_partition" "current" {}

locals {
# Only override_network if both existing_subnet_ids and existing_security_group_id are not null.
override_network = (var.existing_subnet_ids != null) && (var.existing_security_group_id != null)
subnet_ids = local.override_network ? var.existing_subnet_ids : module.network[0].subnet_ids
security_group_id = local.override_network ? var.existing_security_group_id : module.network[0].security_group_id
partition = data.aws_partition.current.partition
override_network = (var.existing_subnet_ids != null) && (var.existing_security_group_id != null)
private_subnet_ids = local.override_network ? var.existing_subnet_ids : module.network[0].private_subnet_ids
security_group_id = local.override_network ? var.existing_security_group_id : module.network[0].security_group_id
partition = data.aws_partition.current.partition
}

# ==================== ACCOUNTING ======================
Expand Down Expand Up @@ -50,6 +50,7 @@ module "network" {

vpc_cidr_block = var.vpc_cidr_block
aws_availability_zones = length(var.availability_zones) >= 2 ? var.availability_zones : slice(sort(data.aws_availability_zones.awszones.names), 0, 2)
region = var.region
}


Expand All @@ -70,7 +71,7 @@ module "efs" {
name = "${local.cluster_name}-jupyterhub-shared"
tags = local.additional_tags

efs_subnets = local.subnet_ids
efs_subnets = local.private_subnet_ids
efs_security_groups = [local.security_group_id]
}

Expand All @@ -88,7 +89,7 @@ module "kubernetes" {
region = var.region
kubernetes_version = var.kubernetes_version

cluster_subnets = local.subnet_ids
cluster_subnets = local.private_subnet_ids
cluster_security_groups = [local.security_group_id]

node_group_additional_policies = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,36 @@ resource "aws_vpc" "main" {
tags = merge({ Name = var.name }, var.tags, var.vpc_tags)
}

resource "aws_subnet" "main" {
resource "aws_subnet" "public" {
count = length(var.aws_availability_zones)

availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index)
vpc_id = aws_vpc.main.id
map_public_ip_on_launch = true
availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index)
vpc_id = aws_vpc.main.id

tags = merge({ Name = "${var.name}-subnet-${count.index}" }, var.tags, var.subnet_tags)
tags = merge({ Name = "${var.name}-public-subnet-${count.index}", "kubernetes.io/role/elb" = 1 }, var.tags, var.subnet_tags)

lifecycle {
ignore_changes = [
availability_zone
]
}
}

moved {
from = aws_subnet.main
to = aws_subnet.public
}


resource "aws_subnet" "private" {
count = length(var.aws_availability_zones)

availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index + length(var.aws_availability_zones))
vpc_id = aws_vpc.main.id

tags = merge({ Name = "${var.name}-private-subnet-${count.index}" }, var.tags, var.subnet_tags)

lifecycle {
ignore_changes = [
Expand All @@ -30,7 +51,25 @@ resource "aws_internet_gateway" "main" {
tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table" "main" {
resource "aws_eip" "nat-gateway-eip" {
count = length(var.aws_availability_zones)

domain = "vpc"

tags = merge({ Name = "${var.name}-nat-gateway-eip-${count.index}" }, var.tags)
}

resource "aws_nat_gateway" "main" {
count = length(var.aws_availability_zones)

allocation_id = aws_eip.nat-gateway-eip[count.index].id
subnet_id = aws_subnet.public[count.index].id

tags = merge({ Name = "${var.name}-nat-gateway-${count.index}" }, var.tags)
depends_on = [aws_internet_gateway.main]
}

resource "aws_route_table" "public" {
vpc_id = aws_vpc.main.id

route {
Expand All @@ -41,11 +80,36 @@ resource "aws_route_table" "main" {
tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table_association" "main" {
moved {
from = aws_route_table.main
to = aws_route_table.public
}

resource "aws_route_table" "private" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.main[count.index].id
route_table_id = aws_route_table.main.id
vpc_id = aws_vpc.main.id

route {
cidr_block = "0.0.0.0/0"
gateway_id = aws_nat_gateway.main[count.index].id
}

tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table_association" "public" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.public[count.index].id
route_table_id = aws_route_table.public.id
}

resource "aws_route_table_association" "private" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.private[count.index].id
route_table_id = aws_route_table.private[count.index].id
}

resource "aws_security_group" "main" {
Expand All @@ -62,7 +126,6 @@ resource "aws_security_group" "main" {
cidr_blocks = [var.vpc_cidr_block]
}

#trivy:ignore:AVD-AWS-0104
egress {
description = "Allow all ports and protocols to exit the security group"
from_port = 0
Expand All @@ -73,3 +136,61 @@ resource "aws_security_group" "main" {

tags = merge({ Name = var.name }, var.tags, var.security_group_tags)
}

resource "aws_vpc_endpoint" "s3" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.s3"
vpc_endpoint_type = "Gateway"
route_table_ids = aws_route_table.private[*].id
tags = merge({ Name = "${var.name}-s3-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "ecr_api" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.ecr.api"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-ecr-api-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "ecr_dkr" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.ecr.dkr"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-ecr-dkr-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "elasticloadbalancing" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.elasticloadbalancing"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-elb-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "sts" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.sts"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-sts-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "eks" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.eks"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-eks-endpoint" }, var.tags)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ output "security_group_id" {
value = aws_security_group.main.id
}

output "subnet_ids" {
description = "AWS VPC subnet ids"
value = aws_subnet.main[*].id
output "public_subnet_ids" {
description = "AWS VPC public subnet ids"
value = aws_subnet.public[*].id
}

output "private_subnet_ids" {
description = "AWS VPC private subnet ids"
value = aws_subnet.private[*].id
}

output "vpc_id" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@ variable "vpc_cidr_block" {
variable "vpc_cidr_newbits" {
description = "VPC cidr number of bits to support 2^N subnets"
type = number
default = 2
default = 2 # allows 4 /18 subnets with 16382 addresses each
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
default = 2 # allows 4 /18 subnets with 16382 addresses each
default = 3 # allows 8 /18 subnets with 16382 addresses each

needed this for my use case with 3 subnets specified

}

variable "region" {
description = "AWS region to operate infrastructure"
type = string

}
3 changes: 2 additions & 1 deletion src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def check_immutable_fields(self):
# Return a default (mutable) extra field schema if bottom level is not a Pydantic model (such as a free-form 'overrides' block)
if isinstance(bottom_level_schema, BaseModel):
extra_field_schema = schema.ExtraFieldSchema(
**bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {}
**type(bottom_level_schema).model_fields[keys[-1]].json_schema_extra
or {}
)
else:
extra_field_schema = schema.ExtraFieldSchema()
Expand Down
Loading