Skip to content

Commit 16ff2a5

Browse files
committed
Added support for private key content for set_auth().
1 parent 9082e98 commit 16ff2a5

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

ads/common/auth.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def set_auth(
117117
>>> ads.set_auth("api_key", oci_config_location = "other_config_location") # use non-default oci_config_location
118118
119119
>>> ads.set_auth("api_key", client_kwargs={"timeout": 60}) # default signer with connection and read timeouts set to 60 seconds for the client.
120-
120+
>>> ads.set_auth("api_key", )
121121
>>> other_config = oci.config.from_file("other_config_location", "OTHER_PROFILE") # Create non-default config
122122
>>> ads.set_auth(config=other_config) # Set api keys type of authentication based on provided config
123123
@@ -157,7 +157,7 @@ def set_auth(
157157

158158
auth_state.oci_config = config
159159
auth_state.oci_key_profile = profile
160-
if auth == AuthType.API_KEY and not signer and not signer_callable:
160+
if auth == AuthType.API_KEY and not signer and not signer_callable and not signer_kwargs:
161161
if os.path.exists(os.path.expanduser(oci_config_location)):
162162
auth_state.oci_config_path = oci_config_location
163163
else:
@@ -175,6 +175,7 @@ def api_keys(
175175
oci_config: str = os.path.join(os.path.expanduser("~"), ".oci", "config"),
176176
profile: str = DEFAULT_PROFILE,
177177
client_kwargs: Dict = None,
178+
kwargs: Dict = None
178179
) -> Dict:
179180
"""
180181
Prepares authentication and extra arguments necessary for creating clients for different OCI services using API
@@ -188,6 +189,15 @@ def api_keys(
188189
Profile name to select from the config file.
189190
client_kwargs: Optional[Dict], default None
190191
kwargs that are required to instantiate the Client if we need to override the defaults.
192+
kwargs:
193+
kwargs for API authentication signer.
194+
- user: OCID of the user calling the API.
195+
- tenancy: OCID of user's tenancy.
196+
- fingerprint: Fingerprint for the public key that was added to this user.
197+
- region: An Oracle Cloud Infrastructure region.
198+
- pass_phrase: Passphrase used for the key, if it is encrypted.
199+
- key_file: Full path and filename of the private key.
200+
- key_content: The private key as PEM string.
191201
192202
Returns
193203
-------
@@ -208,6 +218,7 @@ def api_keys(
208218
oci_config_location=oci_config,
209219
oci_key_profile=profile,
210220
client_kwargs=client_kwargs,
221+
signer_kwargs=kwargs,
211222
)
212223
signer_generator = AuthFactory().signerGenerator(AuthType.API_KEY)
213224
return signer_generator(signer_args).create_signer()
@@ -316,6 +327,7 @@ def create_signer(
316327
oci_config_location=oci_config_location,
317328
oci_key_profile=profile,
318329
oci_config=config,
330+
signer_kwargs=signer_kwargs,
319331
client_kwargs=client_kwargs,
320332
)
321333
if config:
@@ -386,6 +398,7 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
386398
oci_config_location=auth_state.oci_config_path,
387399
oci_key_profile=auth_state.oci_key_profile,
388400
oci_config=auth_state.oci_config,
401+
signer_kwargs=auth_state.oci_signer_kwargs or {},
389402
client_kwargs={
390403
**(auth_state.oci_client_kwargs or {}),
391404
**(client_kwargs or {}),
@@ -470,11 +483,13 @@ def __init__(self, args: Optional[Dict] = None):
470483
- oci_config_location - path to config file
471484
- oci_key_profile - the profile to load from config file
472485
- client_kwargs - optional parameters for OCI client creation in next steps
486+
- signer_kwargs - optional parameters for signer
473487
"""
474488
self.oci_config = args.get("oci_config")
475489
self.oci_config_location = args.get("oci_config_location")
476490
self.oci_key_profile = args.get("oci_key_profile")
477491
self.client_kwargs = args.get("client_kwargs")
492+
self.signer_kwargs = args.get("signer_kwargs")
478493

479494
def create_signer(self) -> Dict:
480495
"""
@@ -503,18 +518,27 @@ def create_signer(self) -> Dict:
503518
if self.oci_config:
504519
configuration = ads.telemetry.update_oci_client_config(self.oci_config)
505520
else:
506-
configuration = ads.telemetry.update_oci_client_config(
507-
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
508-
)
521+
try:
522+
configuration = ads.telemetry.update_oci_client_config(
523+
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
524+
)
525+
except:
526+
if not os.path.exists(os.path.expanduser(self.oci_config_location)):
527+
logger.info(f"Failed to get config from folder {self.oci_config_location}. Using 'signer_kwargs' instead.")
528+
configuration = ads.telemetry.update_oci_client_config(self.signer_kwargs)
529+
else:
530+
raise
531+
509532
logger.info(f"Using 'api_key' authentication.")
510533
return {
511534
"config": configuration,
512535
"signer": oci.signer.Signer(
513-
configuration["tenancy"],
514-
configuration["user"],
515-
configuration["fingerprint"],
516-
configuration["key_file"],
536+
configuration.get("tenancy"),
537+
configuration.get("user"),
538+
configuration.get("fingerprint"),
539+
configuration.get("key_file"),
517540
configuration.get("pass_phrase"),
541+
configuration.get("key_content")
518542
),
519543
"client_kwargs": self.client_kwargs,
520544
}

tests/unitary/default_setup/auth/test_auth.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@
3131
)
3232
from ads.common.oci_logging import OCILog
3333

34+
MOCK_CONFIG_FROM_FILE = {
35+
"user": "test_user",
36+
"fingerprint": "test_fingerprint",
37+
"tenancy": "test_tenancy",
38+
"region": "us-ashburn-1",
39+
"key_file": "test_key_file"
40+
}
41+
3442

3543
class TestEDAMixin(TestCase):
3644
@mock.patch("oci.config.from_file")
@@ -39,6 +47,7 @@ class TestEDAMixin(TestCase):
3947
def test_set_auth_overwrite_profile(
4048
self, mock_load_key_file, mock_path_exists, mock_config_from_file
4149
):
50+
mock_config_from_file.return_value = MOCK_CONFIG_FROM_FILE
4251
set_auth(profile="TEST")
4352
default_signer()
4453
mock_config_from_file.assert_called_with("~/.oci/config", "TEST")
@@ -50,6 +59,7 @@ def test_set_auth_overwrite_profile(
5059
def test_set_auth_overwrite_config_location(
5160
self, mock_load_key_file, mock_path_exists, mock_config_from_file
5261
):
62+
mock_config_from_file.return_value = MOCK_CONFIG_FROM_FILE
5363
mock_path_exists.return_value = True
5464
set_auth(oci_config_location="test_path")
5565
default_signer()
@@ -85,6 +95,25 @@ def test_resource_principal(self, mock_rp_signer):
8595
resource_principal()
8696
mock_rp_signer.assert_called_once()
8797

98+
@mock.patch("oci.signer.load_private_key")
99+
def test_set_auth_with_kwargs(self, mock_load_private_key):
100+
set_auth(
101+
signer_kwargs={
102+
"user": "test_user",
103+
"fingerprint": "test_fingerprint",
104+
"tenancy": "test_tenancy",
105+
"region": "us-ashburn-1",
106+
"key_content": "test_key_content"
107+
}
108+
)
109+
signer = default_signer()
110+
assert signer["config"]["user"] == "test_user"
111+
assert signer["config"]["fingerprint"] == "test_fingerprint"
112+
assert signer["config"]["tenancy"] == "test_tenancy"
113+
assert signer["config"]["region"] == "us-ashburn-1"
114+
assert signer["config"]["key_content"] == "test_key_content"
115+
assert "additional_user_agent" in signer["config"]
116+
assert signer["signer"] != None
88117

89118
class TestOCIMixin(TestCase):
90119
def tearDown(self) -> None:
@@ -352,6 +381,7 @@ def test_set_auth_multiple_times_with(
352381
that default_signer() returns proper signer based on saved state of auth values within AuthState().
353382
Checking that default_signer() runs two times in a row and returns signer based on AuthState().
354383
"""
384+
mock_config_from_file.return_value = MOCK_CONFIG_FROM_FILE
355385
config = dict(
356386
user="ocid1.user.oc1..<unique_ocid>",
357387
fingerprint="00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00",

0 commit comments

Comments
 (0)