Skip to content

Commit cb2560b

Browse files
committed
get_partition_for_region mypy error
1 parent ad5629c commit cb2560b

File tree

1 file changed

+17
-2
lines changed
  • aws_sra_examples/solutions/genai/bedrock_org/lambda/src

1 file changed

+17
-2
lines changed

aws_sra_examples/solutions/genai/bedrock_org/lambda/src/sra_sts.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ class SRASTS:
3333
log_level: str = os.environ.get("LOG_LEVEL", "INFO")
3434
LOGGER.setLevel(log_level)
3535

36+
def _get_partition_for_region(self, region_name: str) -> str:
37+
"""Get AWS partition for a given region.
38+
39+
Args:
40+
region_name (str): AWS region name
41+
42+
Returns:
43+
str: AWS partition name (aws, aws-cn, aws-us-gov)
44+
"""
45+
if region_name.startswith('us-gov-'):
46+
return 'aws-us-gov'
47+
elif region_name.startswith('cn-'):
48+
return 'aws-cn'
49+
return 'aws'
50+
3651
def __init__(self, profile: str = "default") -> None:
3752
"""Initialize class object.
3853
@@ -56,14 +71,14 @@ def __init__(self, profile: str = "default") -> None:
5671
self.STS_CLIENT = self.MANAGEMENT_ACCOUNT_SESSION.client("sts")
5772
self.HOME_REGION = self.MANAGEMENT_ACCOUNT_SESSION.region_name
5873
self.LOGGER.info(f"STS detected home region: {self.HOME_REGION}")
59-
self.PARTITION = self.MANAGEMENT_ACCOUNT_SESSION.get_partition_for_region(self.HOME_REGION)
74+
self.PARTITION = self._get_partition_for_region(self.HOME_REGION)
6075
except botocore.exceptions.ClientError as error:
6176
if error.response["Error"]["Code"] == "ExpiredToken":
6277
self.LOGGER.info("Token has expired, please re-run with proper credentials set.")
6378
self.MANAGEMENT_ACCOUNT_SESSION = boto3.Session()
6479
self.STS_CLIENT = self.MANAGEMENT_ACCOUNT_SESSION.client("sts")
6580
self.HOME_REGION = self.MANAGEMENT_ACCOUNT_SESSION.region_name
66-
self.PARTITION = self.MANAGEMENT_ACCOUNT_SESSION.get_partition_for_region(self.HOME_REGION)
81+
self.PARTITION = self._get_partition_for_region(self.HOME_REGION)
6782

6883
else:
6984
self.LOGGER.info(f"Error: {error}")

0 commit comments

Comments
 (0)