@@ -33,6 +33,21 @@ class SRASTS:
33
33
log_level : str = os .environ .get ("LOG_LEVEL" , "INFO" )
34
34
LOGGER .setLevel (log_level )
35
35
36
+ def _get_partition_for_region (self , region_name : str ) -> str :
37
+ """Get AWS partition for a given region.
38
+
39
+ Args:
40
+ region_name (str): AWS region name
41
+
42
+ Returns:
43
+ str: AWS partition name (aws, aws-cn, aws-us-gov)
44
+ """
45
+ if region_name .startswith ('us-gov-' ):
46
+ return 'aws-us-gov'
47
+ elif region_name .startswith ('cn-' ):
48
+ return 'aws-cn'
49
+ return 'aws'
50
+
36
51
def __init__ (self , profile : str = "default" ) -> None :
37
52
"""Initialize class object.
38
53
@@ -56,14 +71,14 @@ def __init__(self, profile: str = "default") -> None:
56
71
self .STS_CLIENT = self .MANAGEMENT_ACCOUNT_SESSION .client ("sts" )
57
72
self .HOME_REGION = self .MANAGEMENT_ACCOUNT_SESSION .region_name
58
73
self .LOGGER .info (f"STS detected home region: { self .HOME_REGION } " )
59
- self .PARTITION = self .MANAGEMENT_ACCOUNT_SESSION . get_partition_for_region (self .HOME_REGION )
74
+ self .PARTITION = self ._get_partition_for_region (self .HOME_REGION )
60
75
except botocore .exceptions .ClientError as error :
61
76
if error .response ["Error" ]["Code" ] == "ExpiredToken" :
62
77
self .LOGGER .info ("Token has expired, please re-run with proper credentials set." )
63
78
self .MANAGEMENT_ACCOUNT_SESSION = boto3 .Session ()
64
79
self .STS_CLIENT = self .MANAGEMENT_ACCOUNT_SESSION .client ("sts" )
65
80
self .HOME_REGION = self .MANAGEMENT_ACCOUNT_SESSION .region_name
66
- self .PARTITION = self .MANAGEMENT_ACCOUNT_SESSION . get_partition_for_region (self .HOME_REGION )
81
+ self .PARTITION = self ._get_partition_for_region (self .HOME_REGION )
67
82
68
83
else :
69
84
self .LOGGER .info (f"Error: { error } " )
0 commit comments