Skip to content

Commit f050c10

Browse files
committed
Updated pr.
1 parent 23c5859 commit f050c10

File tree

2 files changed

+166
-96
lines changed

2 files changed

+166
-96
lines changed

ads/common/auth.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import copy
8+
from datetime import datetime
89
import os
910
from dataclasses import dataclass
11+
import time
1012
from typing import Any, Callable, Dict, Optional, Union
1113

1214
import ads.telemetry
@@ -32,6 +34,11 @@
3234
"key_file",
3335
"region"
3436
]
37+
SECURITY_TOKEN_LEFT_TIME = 600
38+
39+
40+
class TokenExpiredError(Exception): # pragma: no cover
41+
pass
3542

3643

3744
class AuthType(str, metaclass=ExtendedEnumMeta):
@@ -819,21 +826,65 @@ def create_signer(self) -> Dict:
819826
f"Parameter `{parameter}` must be provided for using `security_token` authentication."
820827
)
821828

822-
if not self.oci_config:
823-
os.system(f'oci session refresh --profile {self.oci_key_profile}')
824-
825829
return {
826830
"config": configuration,
827831
"signer": oci.auth.signers.SecurityTokenSigner(
828-
token=self._read_security_token_file(configuration.get("security_token_file")),
832+
token=self._validate_and_refresh_token(configuration.get("security_token_file")),
829833
private_key=oci.signer.load_private_key_from_file(configuration.get("key_file")),
830834
generic_headers=configuration.get("generic_headers", SECURITY_TOKEN_GENERIC_HEADERS),
831835
body_headers=configuration.get("body_headers", SECURITY_TOKEN_BODY_HEADERS)
832836
),
833837
"client_kwargs": self.client_kwargs,
834838
}
835-
839+
840+
def _validate_and_refresh_token(self, security_token_file: str) -> str:
841+
"""Validates and refreshes security token.
842+
843+
Parameters
844+
----------
845+
security_token_file: str
846+
Path to security token file.
847+
848+
Returns
849+
-------
850+
str:
851+
Security token string.
852+
"""
853+
security_token = self._read_security_token_file(security_token_file)
854+
security_token_container = oci.auth.security_token_container.SecurityTokenContainer(
855+
session_key_supplier=None,
856+
security_token=security_token
857+
)
858+
859+
if not security_token_container.valid():
860+
raise TokenExpiredError(
861+
"Security token has expired. Call `oci session authenticate` to generate new session."
862+
)
863+
864+
time_now = int(time.time())
865+
time_expired = security_token_container.get_jwt()["exp"]
866+
if time_now - time_expired < SECURITY_TOKEN_LEFT_TIME:
867+
if self.oci_config_location == DEFAULT_LOCATION and self.oci_key_profile:
868+
os.system(f'oci session refresh --profile {self.oci_key_profile}')
869+
security_token = self._read_security_token_file(security_token_file)
870+
871+
date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
872+
logger.info(f"Session is valid until {date_time}.")
873+
return security_token
874+
836875
def _read_security_token_file(self, security_token_file: str) -> str:
876+
"""Reads security token from file.
877+
878+
Parameters
879+
----------
880+
security_token_file: str
881+
The path to security token file.
882+
883+
Returns
884+
-------
885+
str:
886+
Security token string.
887+
"""
837888
if not os.path.isfile(security_token_file):
838889
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
839890
try:

tests/unitary/default_setup/auth/test_auth.py

Lines changed: 110 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
EphemeralResourcePrincipalSigner,
1313
)
1414
from oci.auth.signers.security_token_signer import SecurityTokenSigner
15+
from oci.config import DEFAULT_LOCATION
1516
import ads
1617
from ads import set_auth
1718
from ads.common.utils import (
1819
oci_key_profile,
1920
oci_config_location,
2021
)
2122
from ads.common.auth import (
23+
SecurityToken,
24+
TokenExpiredError,
2225
api_keys,
2326
resource_principal,
2427
security_token,
@@ -124,97 +127,6 @@ def test_set_auth_with_key_content(self, mock_load_private_key, mock_validate_co
124127
assert signer["signer"] != None
125128
set_auth()
126129

127-
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
128-
@mock.patch("oci.signer.load_private_key_from_file")
129-
@mock.patch("ads.common.auth.SecurityToken._read_security_token_file")
130-
def test_security_token_from_config(
131-
self,
132-
mock_read_security_token_file,
133-
mock_load_private_key_from_file,
134-
mock_security_token_signer
135-
):
136-
config = {
137-
"fingerprint": "test_fingerprint",
138-
"tenancy": "test_tenancy",
139-
"region": "us-ashburn-1",
140-
"key_file": "test_key_file",
141-
"generic_headers": [1,2,3],
142-
"body_headers": [4,5,6]
143-
}
144-
145-
with pytest.raises(
146-
ValueError,
147-
match="Parameter `security_token_file` must be provided for using `security_token` authentication."
148-
):
149-
signer = security_token(
150-
oci_config=config,
151-
client_kwargs={"test_client_key":"test_client_value"}
152-
)
153-
154-
config["security_token_file"] = "test_security_token"
155-
mock_security_token_signer.return_value = None
156-
signer = security_token(
157-
oci_config=config,
158-
client_kwargs={"test_client_key":"test_client_value"}
159-
)
160-
161-
mock_read_security_token_file.assert_called_with("test_security_token")
162-
mock_load_private_key_from_file.assert_called_with("test_key_file")
163-
assert signer["client_kwargs"] == {"test_client_key": "test_client_value"}
164-
assert "additional_user_agent" in signer["config"]
165-
assert signer["config"]["fingerprint"] == "test_fingerprint"
166-
assert signer["config"]["tenancy"] == "test_tenancy"
167-
assert signer["config"]["region"] == "us-ashburn-1"
168-
assert signer["config"]["security_token_file"] == "test_security_token"
169-
assert signer["config"]["key_file"] == "test_key_file"
170-
assert isinstance(signer["signer"], SecurityTokenSigner)
171-
172-
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
173-
@mock.patch("oci.signer.load_private_key_from_file")
174-
@mock.patch("builtins.open")
175-
@mock.patch("os.path.isfile")
176-
@mock.patch("os.system")
177-
@mock.patch("oci.config.from_file")
178-
def test_security_token_from_file(
179-
self,
180-
mock_from_file,
181-
mock_system,
182-
mock_isfile,
183-
mock_open,
184-
mock_load_private_key_from_file,
185-
mock_security_token_signer
186-
):
187-
mock_from_file.return_value = {
188-
"fingerprint": "test_fingerprint",
189-
"tenancy": "test_tenancy",
190-
"region": "us-ashburn-1",
191-
"key_file": "test_key_file",
192-
"security_token_file": "test_security_token"
193-
}
194-
mock_isfile.return_value = True
195-
mock_security_token_signer.return_value = None
196-
signer = security_token(
197-
oci_config="test_config_location",
198-
profile="test_key_profile",
199-
client_kwargs={"test_client_key":"test_client_value"}
200-
)
201-
202-
mock_from_file.assert_called_with("test_config_location", "test_key_profile")
203-
mock_system.assert_called_with("oci session refresh --profile test_key_profile")
204-
mock_isfile.assert_called_with("test_security_token")
205-
mock_open.assert_called()
206-
mock_load_private_key_from_file.assert_called_with("test_key_file")
207-
mock_security_token_signer.assert_called()
208-
209-
assert signer["client_kwargs"] == {"test_client_key": "test_client_value"}
210-
assert "additional_user_agent" in signer["config"]
211-
assert signer["config"]["fingerprint"] == "test_fingerprint"
212-
assert signer["config"]["tenancy"] == "test_tenancy"
213-
assert signer["config"]["region"] == "us-ashburn-1"
214-
assert signer["config"]["security_token_file"] == "test_security_token"
215-
assert signer["config"]["key_file"] == "test_key_file"
216-
assert isinstance(signer["signer"], SecurityTokenSigner)
217-
218130

219131
class TestOCIMixin(TestCase):
220132
def tearDown(self) -> None:
@@ -621,3 +533,110 @@ def test_with_set_auth_returns_error(self):
621533
with pytest.raises(ValueError):
622534
with AuthContext(auth="not_correct_auth_type"):
623535
pass
536+
537+
538+
class TestSecurityToken(TestCase):
539+
540+
@mock.patch("oci.auth.signers.SecurityTokenSigner.__init__")
541+
@mock.patch("oci.signer.load_private_key_from_file")
542+
@mock.patch("ads.common.auth.SecurityToken._validate_and_refresh_token")
543+
def test_security_token(
544+
self,
545+
mock_validate_and_refresh_token,
546+
mock_load_private_key_from_file,
547+
mock_security_token_signer
548+
):
549+
config = {
550+
"fingerprint": "test_fingerprint",
551+
"tenancy": "test_tenancy",
552+
"region": "us-ashburn-1",
553+
"key_file": "test_key_file",
554+
"generic_headers": [1,2,3],
555+
"body_headers": [4,5,6]
556+
}
557+
558+
with pytest.raises(
559+
ValueError,
560+
match="Parameter `security_token_file` must be provided for using `security_token` authentication."
561+
):
562+
signer = security_token(
563+
oci_config=config,
564+
client_kwargs={"test_client_key":"test_client_value"}
565+
)
566+
567+
config["security_token_file"] = "test_security_token"
568+
mock_security_token_signer.return_value = None
569+
signer = security_token(
570+
oci_config=config,
571+
client_kwargs={"test_client_key":"test_client_value"}
572+
)
573+
574+
mock_validate_and_refresh_token.assert_called_with("test_security_token")
575+
mock_load_private_key_from_file.assert_called_with("test_key_file")
576+
assert signer["client_kwargs"] == {"test_client_key": "test_client_value"}
577+
assert "additional_user_agent" in signer["config"]
578+
assert signer["config"]["fingerprint"] == "test_fingerprint"
579+
assert signer["config"]["tenancy"] == "test_tenancy"
580+
assert signer["config"]["region"] == "us-ashburn-1"
581+
assert signer["config"]["security_token_file"] == "test_security_token"
582+
assert signer["config"]["key_file"] == "test_key_file"
583+
assert isinstance(signer["signer"], SecurityTokenSigner)
584+
585+
@mock.patch("os.system")
586+
@mock.patch("oci.auth.security_token_container.SecurityTokenContainer.get_jwt")
587+
@mock.patch("time.time")
588+
@mock.patch("oci.auth.security_token_container.SecurityTokenContainer.valid")
589+
@mock.patch("oci.auth.security_token_container.SecurityTokenContainer.__init__")
590+
@mock.patch("ads.common.auth.SecurityToken._read_security_token_file")
591+
def test_validate_and_refresh_token(
592+
self,
593+
mock_read_security_token_file,
594+
mock_security_token_container,
595+
mock_valid,
596+
mock_time,
597+
mock_get_jwt,
598+
mock_system
599+
):
600+
security_token = SecurityToken(
601+
args={
602+
"oci_config_location": DEFAULT_LOCATION,
603+
"oci_key_profile": "test_profile"
604+
}
605+
)
606+
mock_security_token_container.return_value = None
607+
608+
mock_valid.return_value = False
609+
with pytest.raises(
610+
TokenExpiredError,
611+
match="Security token has expired. Call `oci session authenticate` to generate new session."
612+
):
613+
security_token._validate_and_refresh_token("test_security_token")
614+
615+
616+
mock_valid.return_value = True
617+
mock_time.return_value = 1
618+
mock_get_jwt.return_value = {"exp" : 1}
619+
620+
security_token._validate_and_refresh_token("test_security_token")
621+
622+
mock_read_security_token_file.assert_called_with("test_security_token")
623+
mock_security_token_container.assert_called()
624+
mock_time.assert_called()
625+
mock_get_jwt.assert_called()
626+
mock_system.assert_called_with("oci session refresh --profile test_profile")
627+
628+
@mock.patch("builtins.open")
629+
@mock.patch("os.path.isfile")
630+
def test_read_security_token_file(self, mock_isfile, mock_open):
631+
security_token = SecurityToken(args={})
632+
633+
mock_isfile.return_value = False
634+
with pytest.raises(
635+
ValueError,
636+
match="Invalid `security_token_file`. Specify a valid path."
637+
):
638+
security_token._read_security_token_file("test_security_token")
639+
640+
mock_isfile.return_value = True
641+
security_token._read_security_token_file("test_security_token")
642+
mock_open.assert_called()

0 commit comments

Comments
 (0)