Skip to content

Commit df547b0

Browse files
authored
ODSC-44564: Improving fs cli endpoint interface (#277)
2 parents 3fb2e08 + 8cc70f6 commit df547b0

File tree

2 files changed

+114
-4
lines changed

2 files changed

+114
-4
lines changed

ads/feature_store/mixin/oci_feature_store.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,70 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import logging
8+
import os
9+
10+
logger = logging.getLogger(__name__)
711
from ads.common.oci_mixin import OCIModelMixin
812
import oci.feature_store
9-
import os
13+
import yaml
14+
15+
16+
try:
17+
from yaml import CDumper as dumper
18+
from yaml import CLoader as loader
19+
except:
20+
from yaml import Dumper as dumper
21+
from yaml import Loader as loader
22+
23+
try:
24+
from odsc_cli.utils import user_fs_config_loc, FsTemplate
25+
except ImportError:
26+
pass
1027

1128

1229
class OCIFeatureStoreMixin(OCIModelMixin):
30+
__mod_time = 0
31+
__template: "FsTemplate" = None
32+
FS_SERVICE_ENDPOINT = "fs_service_endpoint"
33+
SERVICE_ENDPOINT = "service_endpoint"
34+
1335
@classmethod
1436
def init_client(
1537
cls, **kwargs
1638
) -> oci.feature_store.feature_store_client.FeatureStoreClient:
17-
# TODO: Getting the endpoint from authorizer
18-
fs_service_endpoint = os.environ.get("OCI_FS_SERVICE_ENDPOINT")
39+
default_kwargs: dict = cls._get_auth().get("client_kwargs", {})
40+
41+
fs_service_endpoint = (
42+
kwargs.get(cls.FS_SERVICE_ENDPOINT, None)
43+
or kwargs.get(cls.SERVICE_ENDPOINT, None)
44+
or default_kwargs.get(cls.FS_SERVICE_ENDPOINT, None)
45+
)
46+
47+
if not fs_service_endpoint:
48+
try:
49+
mod_time = os.stat(user_fs_config_loc()).st_mtime
50+
if mod_time > cls.__mod_time:
51+
with open(user_fs_config_loc()) as ccf:
52+
cls.__template = FsTemplate(yaml.load(ccf, Loader=loader))
53+
cls.__mod_time = mod_time
54+
except NameError:
55+
logger.info(
56+
"%s",
57+
"Feature store configuration helpers are missing. "
58+
"Support for reading service endpoint from config file is disabled",
59+
)
60+
except FileNotFoundError:
61+
logger.info(
62+
"%s",
63+
"ODSC cli config for feature store was not found",
64+
)
65+
pass
66+
if cls.__template:
67+
fs_service_endpoint = cls.__template.service_endpoint
68+
1969
if fs_service_endpoint:
20-
kwargs = {"service_endpoint": fs_service_endpoint}
70+
kwargs[cls.SERVICE_ENDPOINT] = fs_service_endpoint
2171

2272
client = cls._init_client(
2373
client=oci.feature_store.feature_store_client.FeatureStoreClient, **kwargs
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import ads
8+
from ads.feature_store.mixin.oci_feature_store import OCIFeatureStoreMixin
9+
from ads.common.auth import AuthType
10+
11+
TEST_URL_1 = "https://test1.com"
12+
TEST_URL_2 = "https://test2.com"
13+
TEST_URL_3 = "https://test3.com"
14+
TEST_URL_4 = "https://test4.com"
15+
16+
17+
def test_global_service_endpoint():
18+
ads.set_auth(auth=AuthType.API_KEY, client_kwargs={"service_endpoint": TEST_URL_1})
19+
client = OCIFeatureStoreMixin.init_client()
20+
assert client.base_client.endpoint == f"{TEST_URL_1}/20230101"
21+
22+
23+
def test_global_service_and_fs_endpoints():
24+
ads.set_auth(
25+
auth=AuthType.API_KEY,
26+
client_kwargs={
27+
"fs_service_endpoint": TEST_URL_1,
28+
"service_endpoint": TEST_URL_2,
29+
},
30+
)
31+
client = OCIFeatureStoreMixin.init_client()
32+
assert client.base_client.endpoint == f"{TEST_URL_1}/20230101"
33+
34+
35+
def test_override_service_endpoint():
36+
ads.set_auth(auth=AuthType.API_KEY)
37+
client = OCIFeatureStoreMixin.init_client(service_endpoint=TEST_URL_1)
38+
assert client.base_client.endpoint == f"{TEST_URL_1}/20230101"
39+
40+
41+
def test_override_service_and_fs_endpoints():
42+
ads.set_auth(auth=AuthType.API_KEY)
43+
client = OCIFeatureStoreMixin.init_client(
44+
service_endpoint=TEST_URL_1, fs_service_endpoint=TEST_URL_2
45+
)
46+
assert client.base_client.endpoint == f"{TEST_URL_2}/20230101"
47+
48+
49+
def test_override_service_and_fs_endpoints_with_global_service_and_fs_endpoints():
50+
ads.set_auth(
51+
auth=AuthType.API_KEY,
52+
client_kwargs={
53+
"fs_service_endpoint": TEST_URL_3,
54+
"service_endpoint": TEST_URL_4,
55+
},
56+
)
57+
client = OCIFeatureStoreMixin.init_client(
58+
service_endpoint=TEST_URL_1, fs_service_endpoint=TEST_URL_2
59+
)
60+
assert client.base_client.endpoint == f"{TEST_URL_2}/20230101"

0 commit comments

Comments
 (0)