Skip to content

feat: custom domain name support for private endpoints #3719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cfnlintrc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ ignore_templates:
- tests/translator/output/**/managed_policies_everything.json # intentionally contains wrong arns
- tests/translator/output/**/function_with_provisioned_poller_config.json
- tests/translator/output/**/function_with_metrics_config.json
- tests/translator/output/**/api_with_custom_domains_private.json

ignore_checks:
- E2531 # Deprecated runtime; not relevant for transform tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,16 @@ class Route53(BaseModel):
class Domain(BaseModel):
BasePath: Optional[PassThroughProp] = domain("BasePath")
NormalizeBasePath: Optional[bool] = domain("NormalizeBasePath")
Policy: Optional[PassThroughProp]
CertificateArn: PassThroughProp = domain("CertificateArn")
DomainName: PassThroughProp = passthrough_prop(
DOMAIN_STEM,
"DomainName",
["AWS::ApiGateway::DomainName", "Properties", "DomainName"],
)
EndpointConfiguration: Optional[SamIntrinsicable[Literal["REGIONAL", "EDGE"]]] = domain("EndpointConfiguration")
EndpointConfiguration: Optional[SamIntrinsicable[Literal["REGIONAL", "EDGE", "PRIVATE"]]] = domain(
"EndpointConfiguration"
)
MutualTlsAuthentication: Optional[PassThroughProp] = passthrough_prop(
DOMAIN_STEM,
"MutualTlsAuthentication",
Expand Down
209 changes: 186 additions & 23 deletions samtranslator/model/api/api_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
ApiGatewayApiKey,
ApiGatewayAuthorizer,
ApiGatewayBasePathMapping,
ApiGatewayBasePathMappingV2,
ApiGatewayDeployment,
ApiGatewayDomainName,
ApiGatewayDomainNameV2,
ApiGatewayResponse,
ApiGatewayRestApi,
ApiGatewayStage,
Expand Down Expand Up @@ -79,6 +81,13 @@ class ApiDomainResponse:
recordset_group: Any


@dataclass
class ApiDomainResponseV2:
domain: Optional[ApiGatewayDomainNameV2]
apigw_basepath_mapping_list: Optional[List[ApiGatewayBasePathMappingV2]]
recordset_group: Any


class SharedApiUsagePlan:
"""
Collects API information from different API resources in the same template,
Expand Down Expand Up @@ -517,11 +526,7 @@ def _construct_api_domain( # noqa: PLR0912, PLR0915
if mutual_tls_auth.get("TruststoreVersion", None):
domain.MutualTlsAuthentication["TruststoreVersion"] = mutual_tls_auth["TruststoreVersion"]

if self.domain.get("SecurityPolicy", None):
domain.SecurityPolicy = self.domain["SecurityPolicy"]

if self.domain.get("OwnershipVerificationCertificateArn", None):
domain.OwnershipVerificationCertificateArn = self.domain["OwnershipVerificationCertificateArn"]
self._set_optional_domain_properties(domain)

basepaths: Optional[List[str]]
basepath_value = self.domain.get("BasePath")
Expand All @@ -539,12 +544,102 @@ def _construct_api_domain( # noqa: PLR0912, PLR0915
basepath_resource_list: List[ApiGatewayBasePathMapping] = []

if basepaths is None:
basepath_mapping = ApiGatewayBasePathMapping(
self.logical_id + "BasePathMapping", attributes=self.passthrough_resource_attributes
basepath_mapping = self._create_basepath_mapping(api_domain_name, rest_api, None, None)
basepath_resource_list.extend([basepath_mapping])
else:
sam_expect(basepaths, self.logical_id, "Domain.BasePath").to_be_a_list_of(ExpectedType.STRING)
for basepath in basepaths:
# Remove possible leading and trailing '/' because a base path may only
# contain letters, numbers, and one of "$-_.+!*'()"
path = "".join(e for e in basepath if e.isalnum())
mapping_basepath = path if normalize_basepath else basepath
logical_id = "{}{}{}".format(self.logical_id, path, "BasePathMapping")
basepath_mapping = self._create_basepath_mapping(
api_domain_name, rest_api, logical_id, mapping_basepath
)
basepath_resource_list.extend([basepath_mapping])

# Create the Route53 RecordSetGroup resource
record_set_group = None
route53 = self.domain.get("Route53")
if route53 is not None:
sam_expect(route53, self.logical_id, "Domain.Route53").to_be_a_map()
if route53.get("HostedZoneId") is None and route53.get("HostedZoneName") is None:
raise InvalidResourceException(
self.logical_id,
"HostedZoneId or HostedZoneName is required to enable Route53 support on Custom Domains.",
)

logical_id_suffix = LogicalIdGenerator(
"", route53.get("HostedZoneId") or route53.get("HostedZoneName")
).gen()
logical_id = "RecordSetGroup" + logical_id_suffix

record_set_group = route53_record_set_groups.get(logical_id)

if route53.get("SeparateRecordSetGroup"):
sam_expect(
route53.get("SeparateRecordSetGroup"), self.logical_id, "Domain.Route53.SeparateRecordSetGroup"
).to_be_a_bool()
return ApiDomainResponse(
domain,
basepath_resource_list,
self._construct_single_record_set_group(self.domain, api_domain_name, route53),
)

if not record_set_group:
record_set_group = self._get_record_set_group(logical_id, route53)
route53_record_set_groups[logical_id] = record_set_group

record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, api_domain_name, route53)

return ApiDomainResponse(domain, basepath_resource_list, record_set_group)

def _construct_api_domain_v2(
self, rest_api: ApiGatewayRestApi, route53_record_set_groups: Any
) -> ApiDomainResponseV2:
"""
Constructs and returns the ApiGateway Domain V2 and BasepathMapping V2
"""
if self.domain is None:
return ApiDomainResponseV2(None, None, None)

sam_expect(self.domain, self.logical_id, "Domain").to_be_a_map()
domain_name: PassThrough = sam_expect(
self.domain.get("DomainName"), self.logical_id, "Domain.DomainName"
).to_not_be_none()
certificate_arn: PassThrough = sam_expect(
self.domain.get("CertificateArn"), self.logical_id, "Domain.CertificateArn"
).to_not_be_none()

api_domain_name = "{}{}".format("ApiGatewayDomainNameV2", LogicalIdGenerator("", domain_name).gen())
domain_name_arn = ref(api_domain_name)
domain = ApiGatewayDomainNameV2(api_domain_name, attributes=self.passthrough_resource_attributes)

domain.DomainName = domain_name
endpoint = self.domain.get("EndpointConfiguration")

if endpoint not in ["EDGE", "REGIONAL", "PRIVATE"]:
raise InvalidResourceException(
self.logical_id,
"EndpointConfiguration for Custom Domains must be"
" one of {}.".format(["EDGE", "REGIONAL", "PRIVATE"]),
)
basepath_mapping.DomainName = ref(api_domain_name)
basepath_mapping.RestApiId = ref(rest_api.logical_id)
basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage")

domain.CertificateArn = certificate_arn

domain.EndpointConfiguration = {"Types": [endpoint]}

self._set_optional_domain_properties(domain)

basepaths: Optional[List[str]] = self._get_basepaths()

# Boolean to allow/disallow symbols in BasePath property
normalize_basepath = self.domain.get("NormalizeBasePath", True)

basepath_resource_list: List[ApiGatewayBasePathMappingV2] = []
if basepaths is None:
basepath_mapping = self._create_basepath_mapping_v2(domain_name_arn, rest_api)
basepath_resource_list.extend([basepath_mapping])
else:
sam_expect(basepaths, self.logical_id, "Domain.BasePath").to_be_a_list_of(ExpectedType.STRING)
Expand All @@ -553,10 +648,10 @@ def _construct_api_domain( # noqa: PLR0912, PLR0915
# contain letters, numbers, and one of "$-_.+!*'()"
path = "".join(e for e in basepath if e.isalnum())
logical_id = "{}{}{}".format(self.logical_id, path, "BasePathMapping")
basepath_mapping = ApiGatewayBasePathMapping(
basepath_mapping = ApiGatewayBasePathMappingV2(
logical_id, attributes=self.passthrough_resource_attributes
)
basepath_mapping.DomainName = ref(api_domain_name)
basepath_mapping.DomainNameArn = domain_name_arn
basepath_mapping.RestApiId = ref(rest_api.logical_id)
basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage")
basepath_mapping.BasePath = path if normalize_basepath else basepath
Expand Down Expand Up @@ -584,24 +679,48 @@ def _construct_api_domain( # noqa: PLR0912, PLR0915
sam_expect(
route53.get("SeparateRecordSetGroup"), self.logical_id, "Domain.Route53.SeparateRecordSetGroup"
).to_be_a_bool()
return ApiDomainResponse(
return ApiDomainResponseV2(
domain,
basepath_resource_list,
self._construct_single_record_set_group(self.domain, api_domain_name, route53),
self._construct_single_record_set_group(self.domain, domain_name, route53),
)

if not record_set_group:
record_set_group = Route53RecordSetGroup(logical_id, attributes=self.passthrough_resource_attributes)
if "HostedZoneId" in route53:
record_set_group.HostedZoneId = route53.get("HostedZoneId")
if "HostedZoneName" in route53:
record_set_group.HostedZoneName = route53.get("HostedZoneName")
record_set_group.RecordSets = []
record_set_group = self._get_record_set_group(logical_id, route53)
route53_record_set_groups[logical_id] = record_set_group

record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, api_domain_name, route53)
record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, domain_name, route53)

return ApiDomainResponse(domain, basepath_resource_list, record_set_group)
return ApiDomainResponseV2(domain, basepath_resource_list, record_set_group)

def _get_basepaths(self) -> Optional[List[str]]:
if self.domain is None:
return None
basepath_value = self.domain.get("BasePath")
if self.domain.get("BasePath") and isinstance(basepath_value, str):
return [basepath_value]
if self.domain.get("BasePath") and isinstance(basepath_value, list):
return cast(Optional[List[Any]], basepath_value)
return None

def _set_optional_domain_properties(self, domain: Union[ApiGatewayDomainName, ApiGatewayDomainNameV2]) -> None:
if self.domain is None:
return
if self.domain.get("SecurityPolicy", None):
domain.SecurityPolicy = self.domain["SecurityPolicy"]
if self.domain.get("Policy", None):
domain.Policy = self.domain["Policy"]
if self.domain.get("OwnershipVerificationCertificateArn", None):
domain.OwnershipVerificationCertificateArn = self.domain["OwnershipVerificationCertificateArn"]

def _get_record_set_group(self, logical_id: str, route53: Dict[str, Any]) -> Route53RecordSetGroup:
record_set_group = Route53RecordSetGroup(logical_id, attributes=self.passthrough_resource_attributes)
if "HostedZoneId" in route53:
record_set_group.HostedZoneId = route53.get("HostedZoneId")
if "HostedZoneName" in route53:
record_set_group.HostedZoneName = route53.get("HostedZoneName")
record_set_group.RecordSets = []
return record_set_group

def _construct_single_record_set_group(
self, domain: Dict[str, Any], api_domain_name: str, route53: Any
Expand Down Expand Up @@ -667,6 +786,40 @@ def _construct_alias_target(self, domain: Dict[str, Any], api_domain_name: str,
alias_target["DNSName"] = route53.get("DistributionDomainName")
return alias_target

def _create_basepath_mapping(
self,
api_domain_name: PassThrough,
rest_api: ApiGatewayRestApi,
logical_id: Optional[str],
basepath: Optional[str],
) -> ApiGatewayBasePathMapping:

basepath_mapping: ApiGatewayBasePathMapping
basepath_mapping = (
ApiGatewayBasePathMapping(logical_id, attributes=self.passthrough_resource_attributes)
if logical_id
else ApiGatewayBasePathMapping(
self.logical_id + "BasePathMapping", attributes=self.passthrough_resource_attributes
)
)
basepath_mapping.DomainName = ref(api_domain_name)
basepath_mapping.RestApiId = ref(rest_api.logical_id)
basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage")
if basepath:
basepath_mapping.BasePath = basepath
return basepath_mapping

def _create_basepath_mapping_v2(
self, domain_name_arn: PassThrough, rest_api: ApiGatewayRestApi
) -> ApiGatewayBasePathMappingV2:
basepath_mapping = ApiGatewayBasePathMappingV2(
self.logical_id + "BasePathMapping", attributes=self.passthrough_resource_attributes
)
basepath_mapping.DomainNameArn = domain_name_arn
basepath_mapping.RestApiId = ref(rest_api.logical_id)
basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage")
return basepath_mapping

@cw_timer(prefix="Generator", name="Api")
def to_cloudformation(
self, redeploy_restapi_parameters: Optional[Any], route53_record_set_groups: Dict[str, Route53RecordSetGroup]
Expand All @@ -676,10 +829,19 @@ def to_cloudformation(
:returns: a tuple containing the RestApi, Deployment, and Stage for an empty Api.
:rtype: tuple
"""
api_domain_response: Union[ApiDomainResponseV2, ApiDomainResponse]
domain: Union[Resource, None]
basepath_mapping: Union[List[ApiGatewayBasePathMapping], List[ApiGatewayBasePathMappingV2], None]
rest_api = self._construct_rest_api()
api_domain_response = self._construct_api_domain(rest_api, route53_record_set_groups)
api_domain_response = (
self._construct_api_domain_v2(rest_api, route53_record_set_groups)
if isinstance(self.domain, dict) and self.domain.get("EndpointConfiguration") == "PRIVATE"
else self._construct_api_domain(rest_api, route53_record_set_groups)
)

domain = api_domain_response.domain
basepath_mapping = api_domain_response.apigw_basepath_mapping_list

route53_recordsetGroup = api_domain_response.recordset_group

deployment = self._construct_deployment(rest_api)
Expand All @@ -703,6 +865,7 @@ def to_cloudformation(
Tuple[Resource],
List[LambdaPermission],
List[ApiGatewayBasePathMapping],
List[ApiGatewayBasePathMappingV2],
],
] = []

Expand Down
29 changes: 29 additions & 0 deletions samtranslator/model/apigateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,25 @@ class ApiGatewayDomainName(Resource):
OwnershipVerificationCertificateArn: Optional[PassThrough]


class ApiGatewayDomainNameV2(Resource):
resource_type = "AWS::ApiGateway::DomainNameV2"
property_types = {
"DomainName": GeneratedProperty(),
"EndpointConfiguration": GeneratedProperty(),
"SecurityPolicy": GeneratedProperty(),
"CertificateArn": GeneratedProperty(),
"Tags": GeneratedProperty(),
"Policy": GeneratedProperty(),
}

DomainName: PassThrough
EndpointConfiguration: Optional[PassThrough]
SecurityPolicy: Optional[PassThrough]
CertificateArn: Optional[PassThrough]
Tags: Optional[PassThrough]
Policy: Optional[PassThrough]


class ApiGatewayBasePathMapping(Resource):
resource_type = "AWS::ApiGateway::BasePathMapping"
property_types = {
Expand All @@ -240,6 +259,16 @@ class ApiGatewayBasePathMapping(Resource):
}


class ApiGatewayBasePathMappingV2(Resource):
resource_type = "AWS::ApiGateway::BasePathMappingV2"
property_types = {
"BasePath": GeneratedProperty(),
"DomainNameArn": GeneratedProperty(),
"RestApiId": GeneratedProperty(),
"Stage": GeneratedProperty(),
}


class ApiGatewayUsagePlan(Resource):
resource_type = "AWS::ApiGateway::UsagePlan"
property_types = {
Expand Down
2 changes: 2 additions & 0 deletions samtranslator/model/sam_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
ApiGatewayApiKey,
ApiGatewayDeployment,
ApiGatewayDomainName,
ApiGatewayDomainNameV2,
ApiGatewayStage,
ApiGatewayUsagePlan,
ApiGatewayUsagePlanKey,
Expand Down Expand Up @@ -1310,6 +1311,7 @@ class SamApi(SamResourceMacro):
"Stage": ApiGatewayStage.resource_type,
"Deployment": ApiGatewayDeployment.resource_type,
"DomainName": ApiGatewayDomainName.resource_type,
"DomainNameV2": ApiGatewayDomainNameV2.resource_type,
"UsagePlan": ApiGatewayUsagePlan.resource_type,
"UsagePlanKey": ApiGatewayUsagePlanKey.resource_type,
"ApiKey": ApiGatewayApiKey.resource_type,
Expand Down
6 changes: 5 additions & 1 deletion samtranslator/schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -277250,7 +277250,8 @@
{
"enum": [
"REGIONAL",
"EDGE"
"EDGE",
"PRIVATE"
],
"type": "string"
}
Expand All @@ -277273,6 +277274,9 @@
"title": "OwnershipVerificationCertificateArn",
"type": "string"
},
"Policy": {
"$ref": "#/definitions/PassThroughProp"
},
"Route53": {
"allOf": [
{
Expand Down
Loading