Skip to content

Commit 3be88d7

Browse files
committed
Updated pr.
1 parent f050c10 commit 3be88d7

File tree

2 files changed

+160
-28
lines changed

2 files changed

+160
-28
lines changed

ads/common/auth.py

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66

77
import copy
88
from datetime import datetime
9+
import json
910
import os
1011
from dataclasses import dataclass
1112
import time
1213
from typing import Any, Callable, Dict, Optional, Union
1314

15+
import requests
16+
1417
import ads.telemetry
1518
import oci
19+
from oci_cli import cli_util
1620
from ads.common import logger
1721
from ads.common.decorator.deprecate import deprecated
1822
from ads.common.extended_enum import ExtendedEnumMeta
@@ -37,7 +41,7 @@
3741
SECURITY_TOKEN_LEFT_TIME = 600
3842

3943

40-
class TokenExpiredError(Exception): # pragma: no cover
44+
class SecurityTokenError(Exception): # pragma: no cover
4145
pass
4246

4347

@@ -825,52 +829,106 @@ def create_signer(self) -> Dict:
825829
raise ValueError(
826830
f"Parameter `{parameter}` must be provided for using `security_token` authentication."
827831
)
828-
832+
833+
self._validate_and_refresh_token(configuration)
834+
829835
return {
830836
"config": configuration,
831837
"signer": oci.auth.signers.SecurityTokenSigner(
832-
token=self._validate_and_refresh_token(configuration.get("security_token_file")),
833-
private_key=oci.signer.load_private_key_from_file(configuration.get("key_file")),
838+
token=self._read_security_token_file(configuration.get("security_token_file")),
839+
private_key=oci.signer.load_private_key_from_file(
840+
configuration.get("key_file"), configuration.get("pass_phrase")
841+
),
834842
generic_headers=configuration.get("generic_headers", SECURITY_TOKEN_GENERIC_HEADERS),
835843
body_headers=configuration.get("body_headers", SECURITY_TOKEN_BODY_HEADERS)
836844
),
837845
"client_kwargs": self.client_kwargs,
838846
}
839847

840-
def _validate_and_refresh_token(self, security_token_file: str) -> str:
848+
def _validate_and_refresh_token(self, configuration: Dict[str, Any]):
841849
"""Validates and refreshes security token.
842850
843851
Parameters
844852
----------
845-
security_token_file: str
846-
Path to security token file.
847-
848-
Returns
849-
-------
850-
str:
851-
Security token string.
853+
configuration: Dict
854+
Security token configuration.
852855
"""
853-
security_token = self._read_security_token_file(security_token_file)
856+
security_token = self._read_security_token_file(configuration.get("security_token_file"))
854857
security_token_container = oci.auth.security_token_container.SecurityTokenContainer(
855858
session_key_supplier=None,
856859
security_token=security_token
857860
)
858-
861+
859862
if not security_token_container.valid():
860-
raise TokenExpiredError(
863+
raise SecurityTokenError(
861864
"Security token has expired. Call `oci session authenticate` to generate new session."
862865
)
863866

864867
time_now = int(time.time())
865868
time_expired = security_token_container.get_jwt()["exp"]
866869
if time_now - time_expired < SECURITY_TOKEN_LEFT_TIME:
867-
if self.oci_config_location == DEFAULT_LOCATION and self.oci_key_profile:
868-
os.system(f'oci session refresh --profile {self.oci_key_profile}')
869-
security_token = self._read_security_token_file(security_token_file)
870+
try:
871+
self._refresh_security_token(configuration)
872+
except Exception as ex:
873+
logger.info("Failed to refresh security token. Error: {}".format(ex))
870874

871875
date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
872876
logger.info(f"Session is valid until {date_time}.")
873-
return security_token
877+
878+
def _refresh_security_token(self, configuration: Dict[str, Any]):
879+
"""Refreshes security token. The logic is mainly taken reference from:
880+
https://github.com/oracle/oci-cli/blob/9a0978344950d7b7c24a688892f24968dce20ad3/src/oci_cli/cli_session.py#L152
881+
882+
Parameters
883+
----------
884+
configuration: Dict
885+
Security token configuration.
886+
"""
887+
expanded_security_token_location = os.path.expanduser(
888+
configuration.get("security_token_file")
889+
)
890+
891+
with open(expanded_security_token_location, 'r') as security_token_file:
892+
token = security_token_file.read()
893+
894+
try:
895+
private_key = oci.signer.load_private_key_from_file(
896+
configuration.get("key_file"), configuration.get("pass_phrase")
897+
)
898+
except:
899+
raise
900+
auth = oci.auth.signers.SecurityTokenSigner(token, private_key)
901+
902+
refresh_url = "{endpoint}/v1/authentication/refresh".format(
903+
endpoint=oci.regions.endpoint_for("auth", configuration.get("region"))
904+
)
905+
logger.info(f"Attempting to refresh token from {refresh_url}.")
906+
907+
response = requests.post(
908+
refresh_url,
909+
headers={
910+
'content-type': 'application/json'
911+
},
912+
data=json.dumps({
913+
'currentToken': token
914+
}),
915+
auth=auth
916+
)
917+
918+
if response.status_code == 200:
919+
refreshed_token = json.loads(response.content.decode('UTF-8'))['token']
920+
with open(expanded_security_token_location, 'w') as security_token_file:
921+
security_token_file.write(refreshed_token)
922+
cli_util.apply_user_only_access_permissions(expanded_security_token_location)
923+
logger.info("Successfully refreshed token")
924+
elif response.status_code == 401:
925+
raise SecurityTokenError(
926+
"Security token has expired. Call `oci session authenticate` to generate new session."
927+
)
928+
else:
929+
raise SecurityTokenError(
930+
"Failed to refresh sesison. Error: {}".format(str(response.content.decode('UTF-8')))
931+
)
874932

875933
def _read_security_token_file(self, security_token_file: str) -> str:
876934
"""Reads security token from file.

tests/unitary/default_setup/auth/test_auth.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import os
7+
from mock import MagicMock
78
import pytest
89
from unittest import TestCase, mock
910

@@ -21,7 +22,7 @@
2122
)
2223
from ads.common.auth import (
2324
SecurityToken,
24-
TokenExpiredError,
25+
SecurityTokenError,
2526
api_keys,
2627
resource_principal,
2728
security_token,
@@ -539,10 +540,12 @@ class TestSecurityToken(TestCase):
539540

540541
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
541542
@mock.patch("oci.signer.load_private_key_from_file")
543+
@mock.patch("ads.common.auth.SecurityToken._read_security_token_file")
542544
@mock.patch("ads.common.auth.SecurityToken._validate_and_refresh_token")
543545
def test_security_token(
544546
self,
545547
mock_validate_and_refresh_token,
548+
mock_read_security_token_file,
546549
mock_load_private_key_from_file,
547550
mock_security_token_signer
548551
):
@@ -571,8 +574,9 @@ def test_security_token(
571574
client_kwargs={"test_client_key":"test_client_value"}
572575
)
573576

574-
mock_validate_and_refresh_token.assert_called_with("test_security_token")
575-
mock_load_private_key_from_file.assert_called_with("test_key_file")
577+
mock_validate_and_refresh_token.assert_called_with(config)
578+
mock_read_security_token_file.assert_called_with("test_security_token")
579+
mock_load_private_key_from_file.assert_called_with("test_key_file", None)
576580
assert signer["client_kwargs"] == {"test_client_key": "test_client_value"}
577581
assert "additional_user_agent" in signer["config"]
578582
assert signer["config"]["fingerprint"] == "test_fingerprint"
@@ -582,7 +586,7 @@ def test_security_token(
582586
assert signer["config"]["key_file"] == "test_key_file"
583587
assert isinstance(signer["signer"], SecurityTokenSigner)
584588

585-
@mock.patch("os.system")
589+
@mock.patch("ads.common.auth.SecurityToken._refresh_security_token")
586590
@mock.patch("oci.auth.security_token_container.SecurityTokenContainer.get_jwt")
587591
@mock.patch("time.time")
588592
@mock.patch("oci.auth.security_token_container.SecurityTokenContainer.valid")
@@ -595,7 +599,7 @@ def test_validate_and_refresh_token(
595599
mock_valid,
596600
mock_time,
597601
mock_get_jwt,
598-
mock_system
602+
mock_refresh_security_token
599603
):
600604
security_token = SecurityToken(
601605
args={
@@ -606,24 +610,94 @@ def test_validate_and_refresh_token(
606610
mock_security_token_container.return_value = None
607611

608612
mock_valid.return_value = False
613+
configuration = {
614+
"fingerprint": "test_fingerprint",
615+
"tenancy": "test_tenancy",
616+
"region": "us-ashburn-1",
617+
"key_file": "test_key_file",
618+
"security_token_file": "test_security_token",
619+
"generic_headers": [1,2,3],
620+
"body_headers": [4,5,6]
621+
}
609622
with pytest.raises(
610-
TokenExpiredError,
623+
SecurityTokenError,
611624
match="Security token has expired. Call `oci session authenticate` to generate new session."
612625
):
613-
security_token._validate_and_refresh_token("test_security_token")
626+
security_token._validate_and_refresh_token(configuration)
614627

615628

616629
mock_valid.return_value = True
617630
mock_time.return_value = 1
618631
mock_get_jwt.return_value = {"exp" : 1}
619632

620-
security_token._validate_and_refresh_token("test_security_token")
633+
security_token._validate_and_refresh_token(configuration)
621634

622635
mock_read_security_token_file.assert_called_with("test_security_token")
623636
mock_security_token_container.assert_called()
624637
mock_time.assert_called()
625638
mock_get_jwt.assert_called()
626-
mock_system.assert_called_with("oci session refresh --profile test_profile")
639+
mock_refresh_security_token.assert_called_with(configuration)
640+
641+
@mock.patch("oci_cli.cli_util.apply_user_only_access_permissions")
642+
@mock.patch("json.loads")
643+
@mock.patch("requests.post")
644+
@mock.patch("json.dumps")
645+
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
646+
@mock.patch("oci.signer.load_private_key_from_file")
647+
@mock.patch("builtins.open")
648+
def test_refresh_security_token(
649+
self,
650+
mock_open,
651+
mock_load_private_key_from_file,
652+
mock_security_token_signer,
653+
mock_dumps,
654+
mock_post,
655+
mock_loads,
656+
mock_apply_user_only_access_permissions
657+
):
658+
security_token = SecurityToken(args={})
659+
configuration = {
660+
"fingerprint": "test_fingerprint",
661+
"tenancy": "test_tenancy",
662+
"region": "us-ashburn-1",
663+
"key_file": "test_key_file",
664+
"security_token_file": "test_security_token",
665+
"generic_headers": [1,2,3],
666+
"body_headers": [4,5,6]
667+
}
668+
mock_security_token_signer.return_value = None
669+
mock_loads.return_value = {
670+
"token": "test_token"
671+
}
672+
673+
response = MagicMock()
674+
response.status_code = 401
675+
mock_post.return_value = response
676+
with pytest.raises(
677+
SecurityTokenError,
678+
match="Security token has expired. Call `oci session authenticate` to generate new session."
679+
):
680+
security_token._refresh_security_token(configuration)
681+
682+
response.status_code = 500
683+
mock_post.return_value = response
684+
with pytest.raises(
685+
SecurityTokenError,
686+
):
687+
security_token._refresh_security_token(configuration)
688+
689+
response.status_code = 200
690+
response.content = bytes("test_content", encoding='utf8')
691+
mock_post.return_value = response
692+
security_token._refresh_security_token(configuration)
693+
694+
mock_open.assert_called()
695+
mock_load_private_key_from_file.assert_called_with("test_key_file", None)
696+
mock_security_token_signer.assert_called()
697+
mock_dumps.assert_called()
698+
mock_post.assert_called()
699+
mock_loads.assert_called()
700+
mock_apply_user_only_access_permissions.assert_called()
627701

628702
@mock.patch("builtins.open")
629703
@mock.patch("os.path.isfile")

0 commit comments

Comments
 (0)