Skip to content

Commit 61bd390

Browse files
committed
feat(auth): Added support for Custom Cluster Names (CNAME) for Amazon Redshift Serverless
1 parent 02c1b52 commit 61bd390

File tree

5 files changed

+134
-2
lines changed

5 files changed

+134
-2
lines changed

redshift_connector/iam_helper.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def set_iam_credentials(info: RedshiftProperty) -> None:
223223
if not isinstance(provider, INativePlugin):
224224
# If the Redshift instance has been identified as using a custom domain name, the hostname must
225225
# be determined using the redshift client from boto3 API
226-
if info.is_cname is True:
226+
if info.is_cname is True and not info.is_serverless:
227227
IamHelper.set_cluster_identifier(provider, info)
228228

229229
# Redshift database credentials will be determined using the redshift client from boto3 API
@@ -442,9 +442,18 @@ def set_cluster_credentials(
442442
if get_creds_api_version == IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1:
443443
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#
444444
get_cred_args: typing.Dict[str, str] = {"dbName": info.db_name}
445+
# if a connection parameter for serverless workgroup is provided it will
446+
# be preferred over providing the CustomDomainName. The reason for this
447+
# is backwards compatibility with the following cases:
448+
# 0/ Serverless with NLB
449+
# 1/ Serverless with Custom Domain Name
450+
# Providing the CustomDomainName parameter to getCredentials will lead to
451+
# failure if the custom domain name is not registered with Redshift. Hence,
452+
# the ordering of these conditions is important.
445453
if info.serverless_work_group:
446454
get_cred_args["workgroupName"] = info.serverless_work_group
447-
455+
elif info.is_cname:
456+
get_cred_args["customDomainName"] = info.host
448457
_logger.debug("Calling get_credentials with parameters %s", get_cred_args)
449458
cred = typing.cast(
450459
typing.Dict[str, typing.Union[str, datetime.datetime]],

redshift_connector/redshift_property.py

+4
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def set_region_from_endpoint_lookup(self: "RedshiftProperty") -> None:
280280
ec2_instance_host: str = host_response[0]
281281
_logger.debug("underlying ec2 instance host %s", ec2_instance_host)
282282
ec2_region: str = ec2_instance_host.split(".")[1]
283+
284+
# https://docs.aws.amazon.com/vpc/latest/userguide/vpc-dns.html#vpc-dns-hostnames
285+
if ec2_region == "compute-1":
286+
ec2_region = "us-east-1"
283287
self.put(key="region", value=ec2_region)
284288
except:
285289
msg: str = "Unable to automatically determine AWS region from host {} port {}. Please check host and port connection parameters are correct.".format(

test/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def provisioned_cname_db_kwargs() -> typing.Dict[str, str]:
118118
return db_connect
119119

120120

121+
@pytest.fixture(scope="class")
122+
def serverless_cname_db_kwargs() -> typing.Dict[str, typing.Union[str, bool]]:
123+
db_connect = {
124+
"database": conf.get("redshift-serverless-cname", "database", fallback="mock_database"),
125+
"host": conf.get("redshift-serverless-cname", "host", fallback="cname.mytest.com"),
126+
"db_user": conf.get("redshift-serverless-cname", "db_user", fallback="mock_user"),
127+
"password": conf.get("redshift-serverless-cname", "password", fallback="mock_password"),
128+
"is_serverless": conf.getboolean("redshift-serverless-cname", "is_serverless", fallback="mockboolean"),
129+
}
130+
131+
return db_connect
132+
133+
121134
@pytest.fixture(scope="class")
122135
def okta_idp() -> typing.Dict[str, typing.Union[str, bool, int]]:
123136
db_connect = {

test/manual/test_redshift_custom_domain.py

+44
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,47 @@ def test_nlb_connect() -> None:
7676
}
7777
with redshift_connector.connect(**args): # type: ignore
7878
pass
79+
80+
81+
@pytest.mark.skip(reason="manual")
82+
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
83+
def test_serverless_iam_cname_connect(sslmode, serverless_cname_db_kwargs):
84+
serverless_cname_db_kwargs["iam"] = True
85+
serverless_cname_db_kwargs["profile"] = "default"
86+
serverless_cname_db_kwargs["auto_create"] = True
87+
serverless_cname_db_kwargs["ssl"] = True
88+
serverless_cname_db_kwargs["sslmode"] = sslmode.value
89+
90+
with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
91+
with conn.cursor() as cursor:
92+
cursor.execute("select current_user")
93+
print(cursor.fetchone())
94+
95+
96+
@pytest.mark.skip(reason="manual")
97+
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
98+
def test_serverless_cname_connect(sslmode, serverless_cname_db_kwargs):
99+
# this test requires aws default profile contains valid credentials that provide permissions for
100+
# redshift-serverless:GetCredentials ( Only called from this test method)
101+
import boto3
102+
103+
profile = "default"
104+
client = boto3.client(
105+
service_name="redshift-serverless",
106+
region_name="us-east-1",
107+
)
108+
# fetch cluster credentials and pass them as driver connect parameters
109+
response = client.get_credentials(
110+
customDomainName=serverless_cname_db_kwargs["host"], dbName=serverless_cname_db_kwargs["database"]
111+
)
112+
113+
serverless_cname_db_kwargs["sslmode"] = sslmode.value
114+
serverless_cname_db_kwargs["ssl"] = True
115+
serverless_cname_db_kwargs["user"] = response["dbUser"]
116+
serverless_cname_db_kwargs["password"] = response["dbPassword"]
117+
serverless_cname_db_kwargs["profile"] = profile
118+
119+
with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
120+
with conn.cursor() as cursor:
121+
cursor.execute("select current_user")
122+
print(cursor.fetchone())

test/unit/test_iam_helper.py

+62
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,68 @@ def test_set_cluster_credentials_uses_custom_domain_name_if_custom_domain_name_c
520520
)
521521

522522

523+
@mock.patch("boto3.client.get_credentials")
524+
@mock.patch("boto3.client")
525+
def test_set_cluster_credentials_uses_custom_domain_name_if_custom_domain_name_serverless(
526+
mock_boto_client, mock_get_cluster_credentials
527+
):
528+
mock_cred_provider = MagicMock()
529+
mock_cred_holder = MagicMock()
530+
mock_cred_provider.get_credentials.return_value = mock_cred_holder
531+
mock_cred_provider.get_cache_key.return_value = "mocked"
532+
mock_cred_holder.has_associated_session = False
533+
534+
rp: RedshiftProperty = make_redshift_property()
535+
rp.put("cluster_identifier", None)
536+
rp.put("host", "mycustom.domain.name")
537+
rp.put("is_serverless", True)
538+
rp.put("is_cname", True)
539+
540+
IamHelper.credentials_cache.clear()
541+
542+
IamHelper.set_cluster_credentials(mock_cred_provider, rp)
543+
544+
assert mock_boto_client.called is True
545+
mock_boto_client.assert_has_calls(
546+
[
547+
call().get_credentials(
548+
customDomainName=rp.host,
549+
dbName=rp.db_name,
550+
)
551+
]
552+
)
553+
554+
555+
@mock.patch("boto3.client.get_credentials")
556+
@mock.patch("boto3.client")
557+
def test_set_cluster_credentials_uses_workgroup_if_nlb_serverless(mock_boto_client, mock_get_cluster_credentials):
558+
mock_cred_provider = MagicMock()
559+
mock_cred_holder = MagicMock()
560+
mock_cred_provider.get_credentials.return_value = mock_cred_holder
561+
mock_cred_provider.get_cache_key.return_value = "mocked"
562+
mock_cred_holder.has_associated_session = False
563+
564+
rp: RedshiftProperty = make_redshift_property()
565+
rp.put("cluster_identifier", None)
566+
rp.put("host", "mycustom.domain.name")
567+
rp.put("is_serverless", True)
568+
rp.put("serverless_work_group", "xyz")
569+
rp.put("serverless_acct_id", "012345678901")
570+
IamHelper.credentials_cache.clear()
571+
572+
IamHelper.set_cluster_credentials(mock_cred_provider, rp)
573+
574+
assert mock_boto_client.called is True
575+
mock_boto_client.assert_has_calls(
576+
[
577+
call().get_credentials(
578+
workgroupName=rp.serverless_work_group,
579+
dbName=rp.db_name,
580+
)
581+
]
582+
)
583+
584+
523585
@mock.patch("boto3.client.get_cluster_credentials")
524586
@mock.patch("boto3.client")
525587
def test_set_cluster_credentials_caches_credentials(mock_boto_client, mock_get_cluster_credentials) -> None:

0 commit comments

Comments
 (0)