Skip to content

Commit 043e636

Browse files
committed
ODSC-44564: Improving fs cli endpoint interface
1 parent ba93f40 commit 043e636

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

ads/feature_store/mixin/oci_feature_store.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,66 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import logging
8+
import os
9+
10+
logger = logging.getLogger(__name__)
711
from ads.common.oci_mixin import OCIModelMixin
812
import oci.feature_store
9-
import os
13+
import yaml
14+
15+
16+
try:
17+
from yaml import CDumper as dumper
18+
from yaml import CLoader as loader
19+
except:
20+
from yaml import Dumper as dumper
21+
from yaml import Loader as loader
22+
23+
try:
24+
from odsc_cli.utils import user_fs_config_loc, FsTemplate
25+
except ImportError:
26+
pass
1027

1128

1229
class OCIFeatureStoreMixin(OCIModelMixin):
30+
__mod_time = 0
31+
__template: "FsTemplate" = None
32+
FS_SERVICE_ENDPOINT = "fs_service_endpoint"
33+
SERVICE_ENDPOINT = "service_endpoint"
34+
1335
@classmethod
1436
def init_client(
1537
cls, **kwargs
1638
) -> oci.feature_store.feature_store_client.FeatureStoreClient:
17-
# TODO: Getting the endpoint from authorizer
18-
fs_service_endpoint = os.environ.get("OCI_FS_SERVICE_ENDPOINT")
39+
default_kwargs: dict = cls._get_auth().get("client_kwargs", {})
40+
41+
fs_service_endpoint = (
42+
kwargs.get(cls.FS_SERVICE_ENDPOINT, None)
43+
or kwargs.get(cls.SERVICE_ENDPOINT, None)
44+
or default_kwargs.get(cls.FS_SERVICE_ENDPOINT, None)
45+
)
46+
47+
if not fs_service_endpoint:
48+
try:
49+
mod_time = os.stat(user_fs_config_loc()).st_mtime
50+
if mod_time > cls.__mod_time:
51+
with open(user_fs_config_loc()) as ccf:
52+
cls.__template = FsTemplate(yaml.load(ccf, Loader=loader))
53+
cls.__mod_time = mod_time
54+
except NameError:
55+
logger.info(
56+
"%s",
57+
"Feature store configuration helpers are missing. "
58+
"Support for reading service endpoint from config file is disabled",
59+
)
60+
except FileNotFoundError:
61+
pass
62+
if cls.__template:
63+
fs_service_endpoint = cls.__template.service_endpoint
64+
1965
if fs_service_endpoint:
20-
kwargs = {"service_endpoint": fs_service_endpoint}
66+
kwargs[cls.SERVICE_ENDPOINT] = fs_service_endpoint
2167

2268
client = cls._init_client(
2369
client=oci.feature_store.feature_store_client.FeatureStoreClient, **kwargs
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import ads
8+
from ads.feature_store.mixin.oci_feature_store import OCIFeatureStoreMixin
9+
from ads.common.auth import AuthType
10+
11+
TEST_URL = "https://test.com"
12+
13+
14+
def test_manual_endpoint():
15+
ads.set_auth(auth=AuthType.API_KEY, client_kwargs={"fs_service_endpoint": TEST_URL})
16+
client = OCIFeatureStoreMixin.init_client(fs_service_endpoint=TEST_URL)
17+
assert client.base_client.endpoint == f"{TEST_URL}/20230101"
18+
19+
20+
def test_manual_with_service_endpoint():
21+
ads.set_auth(
22+
auth=AuthType.API_KEY,
23+
client_kwargs={
24+
"fs_service_endpoint": TEST_URL,
25+
"service_endpoint": "service.com",
26+
},
27+
)
28+
client = OCIFeatureStoreMixin.init_client(fs_service_endpoint=TEST_URL)
29+
assert client.base_client.endpoint == f"{TEST_URL}/20230101"
30+
31+
32+
def test_service_endpoint():
33+
ads.set_auth(auth=AuthType.API_KEY, client_kwargs={"service_endpoint": TEST_URL})
34+
client = OCIFeatureStoreMixin.init_client(fs_service_endpoint=TEST_URL)
35+
assert client.base_client.endpoint == f"{TEST_URL}/20230101"

0 commit comments

Comments
 (0)