Skip to content

Commit ee0633b

Browse files
authored
Bugfix/fetch ds endpoint in mixin (#514)
1 parent 3cdcea2 commit ee0633b

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

ads/common/oci_datascience.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import os
8+
79
import oci.data_science
8-
from ads.common.oci_mixin import OCIModelMixin
10+
911
from ads.common.decorator.utils import class_or_instance_method
12+
from ads.common.oci_mixin import OCIModelMixin
13+
14+
ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
1015

1116

1217
class OCIDataScienceMixin(OCIModelMixin):
1318
@class_or_instance_method
1419
def init_client(cls, **kwargs) -> oci.data_science.DataScienceClient:
20+
client_kwargs = kwargs.get("client_kwargs", {})
21+
if os.environ.get(ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT):
22+
client_kwargs.update(
23+
dict(service_endpoint=os.environ.get(ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT))
24+
)
25+
kwargs.update(client_kwargs)
1526
return cls._init_client(client=oci.data_science.DataScienceClient, **kwargs)
1627

1728
@property

ads/common/oci_mixin.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
"""Contains Mixins for integrating OCI data models
@@ -13,18 +13,19 @@
1313
import re
1414
import traceback
1515
from datetime import date, datetime
16-
from typing import Callable, Optional, Union
1716
from enum import Enum
17+
from typing import Callable, Optional, Union
1818

1919
import oci
2020
import yaml
21+
from dateutil import tz
22+
from dateutil.parser import parse
23+
from oci._vendor import six
24+
2125
from ads.common import auth
2226
from ads.common.decorator.utils import class_or_instance_method
2327
from ads.common.utils import camel_to_snake
2428
from ads.config import COMPARTMENT_OCID
25-
from dateutil import tz
26-
from dateutil.parser import parse
27-
from oci._vendor import six
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -273,7 +274,7 @@ def deserialize(cls, data, to_cls):
273274
else:
274275
return cls.__deserialize_model(data, to_cls)
275276

276-
@classmethod
277+
@class_or_instance_method
277278
def __deserialize_model(cls, data, to_cls):
278279
"""De-serializes list or dict to model."""
279280
if isinstance(data, to_cls):

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
"""Unit tests for model frameworks. Includes tests for:
@@ -1153,15 +1153,15 @@ def test_update_deployment_class_level(
11531153
def test_update_deployment_instance_level_with_id(
11541154
self, mock_client, mock_signer, mock_update
11551155
):
1156+
mock_signer.return_value = {}
11561157
test_model_deployment_id = "xxxx.datasciencemodeldeployment.xxxx"
11571158
md_props = ModelDeploymentProperties(model_id=test_model_deployment_id)
11581159
md = ModelDeployment(properties=md_props)
11591160

11601161
test_model = MagicMock(model_deployment=md, _summary_status=SummaryStatus())
11611162
mock_update.return_value = test_model
11621163

1163-
generic_model = GenericModel(estimator=TestEstimator())
1164-
test_result = generic_model.update_deployment(
1164+
test_result = self.generic_model.update_deployment(
11651165
model_deployment_id=test_model_deployment_id,
11661166
properties=None,
11671167
wait_for_completion=True,
@@ -1430,11 +1430,10 @@ def test_restart_deployment(
14301430
test_model_deployment_id = "xxxx.datasciencemodeldeployment.xxxx"
14311431
md_props = ModelDeploymentProperties(model_id=test_model_deployment_id)
14321432
md = ModelDeployment(properties=md_props)
1433-
generic_model = GenericModel(estimator=TestEstimator())
1434-
generic_model.model_deployment = md
1433+
self.generic_model.model_deployment = md
14351434
mock_deactivate.return_value = md
14361435
mock_activate.return_value = md
1437-
generic_model.restart_deployment(max_wait_time=2000, poll_interval=50)
1436+
self.generic_model.restart_deployment(max_wait_time=2000, poll_interval=50)
14381437
mock_deactivate.assert_called_with(max_wait_time=2000, poll_interval=50)
14391438
mock_activate.assert_called_with(max_wait_time=2000, poll_interval=50)
14401439

0 commit comments

Comments
 (0)