Skip to content

Commit 052caac

Browse files
committed
Added support for security token authentication
1 parent bb86b6e commit 052caac

File tree

2 files changed

+229
-3
lines changed

2 files changed

+229
-3
lines changed

ads/common/auth.py

Lines changed: 157 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AuthType(str, metaclass=ExtendedEnumMeta):
2222
API_KEY = "api_key"
2323
RESOURCE_PRINCIPAL = "resource_principal"
2424
INSTANCE_PRINCIPAL = "instance_principal"
25+
SECURITY_TOKEN = "security_token"
2526

2627

2728
class SingletonMeta(type):
@@ -140,6 +141,14 @@ def set_auth(
140141
141142
>>> ads.set_auth("instance_principal") # Set instance principal authentication
142143
144+
>>> ads.set_auth("security_token") # Set security token authentication
145+
146+
>>> config = dict(
147+
... key_file=~/.oci/sessions/DEFAULT/oci_api_key.pem
148+
... security_token_file=~/.oci/sessions/DEFAULT/token
149+
... )
150+
>>> ads.set_auth("security_token", config=config) # Set security token authentication from provided config
151+
143152
>>> singer = oci.signer.Signer(
144153
... user=ocid1.user.oc1..<unique_ID>,
145154
... fingerprint=<fingerprint>,
@@ -274,6 +283,50 @@ def resource_principal(
274283
return signer_generator(signer_args).create_signer()
275284

276285

286+
def security_token(
287+
oci_config: Union[str, Dict] = os.path.expanduser(DEFAULT_LOCATION),
288+
profile: str = DEFAULT_PROFILE,
289+
client_kwargs: Dict = None,
290+
) -> Dict:
291+
"""
292+
Prepares authentication and extra arguments necessary for creating clients for different OCI services using Security Token.
293+
294+
Parameters
295+
----------
296+
oci_config: Optional[Union[str, Dict]], default is ~/.oci/config
297+
OCI authentication config file location or a dictionary with config attributes.
298+
profile: Optional[str], is DEFAULT_PROFILE, which is 'DEFAULT'
299+
Profile name to select from the config file.
300+
client_kwargs: Optional[Dict], default None
301+
kwargs that are required to instantiate the Client if we need to override the defaults.
302+
303+
Returns
304+
-------
305+
dict
306+
Contains keys - config, signer and client_kwargs.
307+
308+
- The config contains the config loaded from the configuration loaded from `oci_config`.
309+
- The signer contains the signer object created from the security token.
310+
- client_kwargs contains the `client_kwargs` that was passed in as input parameter.
311+
312+
Examples
313+
--------
314+
>>> from ads.common import oci_client as oc
315+
>>> auth = ads.auth.security_token(oci_config="/home/datascience/.oci/config", profile="TEST", client_kwargs={"timeout": 6000})
316+
>>> oc.OCIClientFactory(**auth).object_storage # Creates Object storage client with timeout set to 6000 using Security Token authentication
317+
"""
318+
signer_args = dict(
319+
oci_config=oci_config if isinstance(oci_config, Dict) else {},
320+
oci_config_location=oci_config
321+
if isinstance(oci_config, str)
322+
else os.path.expanduser(DEFAULT_LOCATION),
323+
oci_key_profile=profile,
324+
client_kwargs=client_kwargs,
325+
)
326+
signer_generator = AuthFactory().signerGenerator(AuthType.SECURITY_TOKEN)
327+
return signer_generator(signer_args).create_signer()
328+
329+
277330
def create_signer(
278331
auth_type: Optional[str] = AuthType.API_KEY,
279332
oci_config_location: Optional[str] = DEFAULT_LOCATION,
@@ -346,6 +399,11 @@ def create_signer(
346399
>>> signer_callable = oci.auth.signers.InstancePrincipalsSecurityTokenSigner
347400
>>> signer_kwargs = dict(log_requests=True) # will log the request url and response data when retrieving
348401
>>> auth = ads.auth.create_signer(signer_callable=signer_callable, signer_kwargs=signer_kwargs) # instance principals authentication dictionary created based on callable with kwargs parameters
402+
>>> config = dict(
403+
... key_file=~/.oci/sessions/DEFAULT/oci_api_key.pem
404+
... security_token_file=~/.oci/sessions/DEFAULT/token
405+
... )
406+
>>> auth = ads.auth.create_signer(auth_type="security_token", config=config) # security token authentication created based on provided config
349407
"""
350408
if signer or signer_callable:
351409
configuration = ads.telemetry.update_oci_client_config()
@@ -365,8 +423,6 @@ def create_signer(
365423
oci_config=config,
366424
client_kwargs=client_kwargs,
367425
)
368-
if config:
369-
auth_type = AuthType.API_KEY
370426

371427
signer_generator = AuthFactory().signerGenerator(auth_type)
372428

@@ -678,6 +734,102 @@ def create_signer(self) -> Dict:
678734
return signer_dict
679735

680736

737+
class SecurityToken(AuthSignerGenerator):
738+
def __init__(self, args: Optional[Dict] = None):
739+
"""
740+
Signer created based on args provided. If not provided current values of according arguments
741+
will be used from current global state from AuthState class.
742+
743+
Parameters
744+
----------
745+
args: dict
746+
args that are required to create Security Token signer. Contains keys: oci_config,
747+
oci_config_location, oci_key_profile, client_kwargs.
748+
749+
- oci_config is a configuration dict that can be used to create clients
750+
- oci_config_location - path to config file
751+
- oci_key_profile - the profile to load from config file
752+
- client_kwargs - optional parameters for OCI client creation in next steps
753+
"""
754+
self.oci_config = args.get("oci_config")
755+
self.oci_config_location = args.get("oci_config_location")
756+
self.oci_key_profile = args.get("oci_key_profile")
757+
self.client_kwargs = args.get("client_kwargs")
758+
759+
def create_signer(self) -> Dict:
760+
"""
761+
Creates security token configuration and signer with extra arguments necessary for creating clients.
762+
Signer constructed from the `oci_config` provided. If not 'oci_config', configuration will be
763+
constructed from 'oci_config_location' and 'oci_key_profile' in place.
764+
765+
Returns
766+
-------
767+
dict
768+
Contains keys - config, signer and client_kwargs.
769+
770+
- config contains the configuration information
771+
- signer contains the signer object created. It is instantiated from signer_callable, or
772+
signer provided in args used, or instantiated in place
773+
- client_kwargs contains the `client_kwargs` that was passed in as input parameter
774+
775+
Examples
776+
--------
777+
>>> signer_args = dict(
778+
... client_kwargs=client_kwargs
779+
... )
780+
>>> signer_generator = AuthFactory().signerGenerator(AuthType.SECURITY_TOKEN)
781+
>>> signer_generator(signer_args).create_signer()
782+
"""
783+
if self.oci_config:
784+
configuration = ads.telemetry.update_oci_client_config(self.oci_config)
785+
else:
786+
configuration = ads.telemetry.update_oci_client_config(
787+
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
788+
)
789+
790+
logger.info(f"Using 'security_token' authentication.")
791+
792+
if "security_token_file" not in configuration and "security_token_content" not in configuration:
793+
raise ValueError(
794+
"Parameter `security_token_file` or `security_token_content` must be provided for using `security_token` authentication."
795+
)
796+
797+
if "key_file" not in configuration and "key_content" not in configuration:
798+
raise ValueError(
799+
"Parameter `key_file` or `key_content` must be provided for using `security_token` authentication."
800+
)
801+
802+
if "security_token_content" not in configuration and not self.oci_config:
803+
os.system(f'oci session refresh --profile {self.oci_key_profile or DEFAULT_PROFILE}')
804+
805+
return {
806+
"config": configuration,
807+
"signer": oci.auth.signers.SecurityTokenSigner(
808+
token=(
809+
configuration.get("security_token_content", None)
810+
or self._read_security_token_file(configuration.get("security_token_file"))
811+
),
812+
private_key=(
813+
oci.signer.load_private_key(configuration.get("key_content"))
814+
if configuration.get("key_content")
815+
else oci.signer.load_private_key_from_file(configuration.get("key_file"))
816+
),
817+
generic_headers=configuration.get("generic_headers"),
818+
body_headers=configuration.get("body_headers")
819+
),
820+
"client_kwargs": self.client_kwargs,
821+
}
822+
823+
def _read_security_token_file(self, security_token_file: str) -> str:
824+
try:
825+
token = None
826+
with open(security_token_file, 'r') as f:
827+
token = f.read()
828+
return token
829+
except:
830+
raise
831+
832+
681833
class AuthFactory:
682834
"""
683835
AuthFactory class which contains list of registered signers and alllows to register new signers.
@@ -687,12 +839,14 @@ class AuthFactory:
687839
* APIKey
688840
* ResourcePrincipal
689841
* InstancePrincipal
842+
* SecurityToken
690843
"""
691844

692845
classes = {
693846
AuthType.API_KEY: APIKey,
694847
AuthType.RESOURCE_PRINCIPAL: ResourcePrincipal,
695848
AuthType.INSTANCE_PRINCIPAL: InstancePrincipal,
849+
AuthType.SECURITY_TOKEN: SecurityToken,
696850
}
697851

698852
@classmethod
@@ -726,7 +880,7 @@ def signerGenerator(self, iam_type: Optional[str] = "api_key"):
726880
727881
Returns
728882
-------
729-
:class:`APIKey` or :class:`ResourcePrincipal` or :class:`InstancePrincipal`
883+
:class:`APIKey` or :class:`ResourcePrincipal` or :class:`InstancePrincipal` or :class:`SecurityToken`
730884
returns one of classes, which implements creation of signer of specified type
731885
732886
Raises

tests/unitary/default_setup/auth/test_auth.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from oci.auth.signers.ephemeral_resource_principals_signer import (
1212
EphemeralResourcePrincipalSigner,
1313
)
14+
from oci.auth.signers.security_token_signer import SecurityTokenSigner
1415
import ads
1516
from ads import set_auth
1617
from ads.common.utils import (
@@ -20,6 +21,7 @@
2021
from ads.common.auth import (
2122
api_keys,
2223
resource_principal,
24+
security_token,
2325
create_signer,
2426
default_signer,
2527
get_signer,
@@ -122,6 +124,76 @@ def test_set_auth_with_key_content(self, mock_load_private_key, mock_validate_co
122124
assert signer["signer"] != None
123125
set_auth()
124126

127+
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
128+
@mock.patch("oci.signer.load_private_key")
129+
@mock.patch("oci.signer.load_private_key_from_file")
130+
@mock.patch("ads.common.auth.SecurityToken._read_security_token_file")
131+
def test_security_token(
132+
self,
133+
mock_read_security_token_file,
134+
mock_load_private_key_from_file,
135+
mock_load_private_key,
136+
mock_security_token_signer
137+
):
138+
config = {
139+
"fingerprint": "test_fingerprint",
140+
"tenancy": "test_tenancy",
141+
"region": "us-ashburn-1",
142+
"generic_headers": [1,2,3],
143+
"body_headers": [4,5,6]
144+
}
145+
146+
with pytest.raises(
147+
ValueError,
148+
match="Parameter `security_token_file` or `security_token_content` must be provided for using `security_token` authentication."
149+
):
150+
signer = security_token(
151+
oci_config=config,
152+
client_kwargs={"test_client_key":"test_client_value"}
153+
)
154+
155+
config["security_token_file"] = "test_security_token"
156+
with pytest.raises(
157+
ValueError,
158+
match="Parameter `key_file` or `key_content` must be provided for using `security_token` authentication."
159+
):
160+
signer = security_token(
161+
oci_config=config,
162+
client_kwargs={"test_client_key":"test_client_value"}
163+
)
164+
165+
config["key_file"] = "test_key_file"
166+
mock_security_token_signer.return_value = None
167+
signer = security_token(
168+
oci_config=config,
169+
client_kwargs={"test_client_key":"test_client_value"}
170+
)
171+
172+
mock_read_security_token_file.assert_called_with("test_security_token")
173+
mock_load_private_key_from_file.assert_called_with("test_key_file")
174+
assert signer["client_kwargs"] == {"test_client_key": "test_client_value"}
175+
assert "additional_user_agent" in signer["config"]
176+
assert signer["config"]["fingerprint"] == "test_fingerprint"
177+
assert signer["config"]["tenancy"] == "test_tenancy"
178+
assert signer["config"]["region"] == "us-ashburn-1"
179+
assert signer["config"]["security_token_file"] == "test_security_token"
180+
assert signer["config"]["key_file"] == "test_key_file"
181+
assert isinstance(signer["signer"], SecurityTokenSigner)
182+
183+
config = {
184+
"fingerprint": "test_fingerprint",
185+
"tenancy": "test_tenancy",
186+
"region": "us-ashburn-1",
187+
"security_token_content": "test_security_token_content",
188+
"key_content": "test_key_content"
189+
}
190+
signer = security_token(
191+
oci_config=config,
192+
client_kwargs={"test_client_key":"test_client_value"}
193+
)
194+
mock_load_private_key.assert_called_with("test_key_content")
195+
196+
125197
class TestOCIMixin(TestCase):
126198
def tearDown(self) -> None:
127199
with mock.patch("os.path.exists"):

0 commit comments

Comments
 (0)