Skip to content

Commit 79722d0

Browse files
mrDzurbcodeloopahoslerprasankhgovarsha
authored
Forecasting Operator (#268)
Co-authored-by: Vikas Pandey <vikas.v.pandey@oracle.com> Co-authored-by: Allen Hosler <allen.hosler@oracle.com> Co-authored-by: Prashant Sankhla <prashant.s.sankhla@oracle.com> Co-authored-by: Vikas Pandey <vikaspandey707@gmail.com> Co-authored-by: Prashant Sankhla <sankhlaprashant15@gmail.com> Co-authored-by: Goalla Varsha <goalla.v.varsha@oracle.com> Co-authored-by: MING KANG <ming.kang@oracle.com>
1 parent 2f9be39 commit 79722d0

File tree

201 files changed

+15472
-723
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

201 files changed

+15472
-723
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,6 @@ logs/
160160

161161
# vim
162162
*.swp
163+
164+
# Python Wheel
165+
*.whl

MANIFEST.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ exclude build/lib/notebooks/**
2424
exclude benchmark/**
2525
include ads/ads
2626
include ads/model/common/*.*
27+
include ads/operator/**/*.md
28+
include ads/operator/**/*.yaml
29+
include ads/operator/**/*.whl
30+
include ads/operator/**/MLoperator

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ You have various options when installing ADS.
2727
python3 -m pip install oracle-ads
2828
```
2929

30+
### Installing OCI AI Operators
31+
32+
To use the AI Forecast Operator, install the "forecast" dependencies using the following command:
33+
34+
```bash
35+
python3 -m pip install 'oracle_ads[forecast]==2.9.0'
36+
```
37+
3038
### Installing extras libraries
3139

3240
To work with gradient boosting models, install the `boosted` module. This module includes XGBoost and LightGBM model classes.

ads/cli.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import ads.opctl.cli
1515
import ads.jobs.cli
1616
import ads.pipeline.cli
17-
import os
18-
import json
17+
import ads.opctl.operator.cli
1918
except Exception as ex:
2019
print(
2120
"Please run `pip install oracle-ads[opctl]` to install "
22-
"the required dependencies for ADS CLI."
21+
"the required dependencies for ADS CLI. \n"
22+
f"{str(ex)}"
2323
)
2424
logger.debug(ex)
2525
logger.debug(traceback.format_exc())
@@ -44,6 +44,7 @@ def cli():
4444
cli.add_command(ads.opctl.cli.commands)
4545
cli.add_command(ads.jobs.cli.commands)
4646
cli.add_command(ads.pipeline.cli.commands)
47+
cli.add_command(ads.opctl.operator.cli.commands)
4748

4849

4950
if __name__ == "__main__":

ads/common/auth.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,8 @@ def create_signer(self) -> Dict:
629629
user=configuration["user"],
630630
fingerprint=configuration["fingerprint"],
631631
private_key_file_location=configuration.get("key_file"),
632-
pass_phrase= configuration.get("pass_phrase"),
633-
private_key_content=configuration.get("key_content")
632+
pass_phrase=configuration.get("pass_phrase"),
633+
private_key_content=configuration.get("key_content"),
634634
),
635635
"client_kwargs": self.client_kwargs,
636636
}
@@ -750,21 +750,10 @@ class SecurityToken(AuthSignerGenerator):
750750
a given user - it requires that user's private key and security token.
751751
It prepares extra arguments necessary for creating clients for variety of OCI services.
752752
"""
753-
SECURITY_TOKEN_GENERIC_HEADERS = [
754-
"date",
755-
"(request-target)",
756-
"host"
757-
]
758-
SECURITY_TOKEN_BODY_HEADERS = [
759-
"content-length",
760-
"content-type",
761-
"x-content-sha256"
762-
]
763-
SECURITY_TOKEN_REQUIRED = [
764-
"security_token_file",
765-
"key_file",
766-
"region"
767-
]
753+
754+
SECURITY_TOKEN_GENERIC_HEADERS = ["date", "(request-target)", "host"]
755+
SECURITY_TOKEN_BODY_HEADERS = ["content-length", "content-type", "x-content-sha256"]
756+
SECURITY_TOKEN_REQUIRED = ["security_token_file", "key_file", "region"]
768757

769758
def __init__(self, args: Optional[Dict] = None):
770759
"""
@@ -831,12 +820,18 @@ def create_signer(self) -> Dict:
831820
return {
832821
"config": configuration,
833822
"signer": oci.auth.signers.SecurityTokenSigner(
834-
token=self._read_security_token_file(configuration.get("security_token_file")),
823+
token=self._read_security_token_file(
824+
configuration.get("security_token_file")
825+
),
835826
private_key=oci.signer.load_private_key_from_file(
836827
configuration.get("key_file"), configuration.get("pass_phrase")
837828
),
838-
generic_headers=configuration.get("generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS),
839-
body_headers=configuration.get("body_headers", self.SECURITY_TOKEN_BODY_HEADERS)
829+
generic_headers=configuration.get(
830+
"generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS
831+
),
832+
body_headers=configuration.get(
833+
"body_headers", self.SECURITY_TOKEN_BODY_HEADERS
834+
),
840835
),
841836
"client_kwargs": self.client_kwargs,
842837
}
@@ -849,30 +844,37 @@ def _validate_and_refresh_token(self, configuration: Dict[str, Any]):
849844
configuration: Dict
850845
Security token configuration.
851846
"""
852-
security_token = self._read_security_token_file(configuration.get("security_token_file"))
853-
security_token_container = oci.auth.security_token_container.SecurityTokenContainer(
854-
session_key_supplier=None,
855-
security_token=security_token
847+
security_token = self._read_security_token_file(
848+
configuration.get("security_token_file")
849+
)
850+
security_token_container = (
851+
oci.auth.security_token_container.SecurityTokenContainer(
852+
session_key_supplier=None, security_token=security_token
853+
)
856854
)
857855

858856
if not security_token_container.valid():
859857
raise SecurityTokenError(
860858
"Security token has expired. Call `oci session authenticate` to generate new session."
861859
)
862-
860+
863861
time_now = int(time.time())
864862
time_expired = security_token_container.get_jwt()["exp"]
865863
if time_expired - time_now < SECURITY_TOKEN_LEFT_TIME:
866864
if not self.oci_config_location:
867-
logger.warning("Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer().")
865+
logger.warning(
866+
"Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer()."
867+
)
868868
else:
869-
result = os.system(f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}")
869+
result = os.system(
870+
f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}"
871+
)
870872
if result == 1:
871873
logger.warning(
872874
"Some error happened during auto-refreshing the token. Continue using the current one that's expiring in less than {SECURITY_TOKEN_LEFT_TIME} seconds."
873875
"Please follow steps in https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/clitoken.htm to renew token."
874876
)
875-
877+
876878
date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
877879
logger.info(f"Session is valid until {date_time}.")
878880

@@ -894,7 +896,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:
894896
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
895897
try:
896898
token = None
897-
with open(expanded_path, 'r') as f:
899+
with open(expanded_path, "r") as f:
898900
token = f.read()
899901
return token
900902
except:
@@ -903,7 +905,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:
903905

904906
class AuthFactory:
905907
"""
906-
AuthFactory class which contains list of registered signers and alllows to register new signers.
908+
AuthFactory class which contains list of registered signers and allows to register new signers.
907909
Check documentation for more signers: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html.
908910
909911
Current signers:

ads/common/decorator/runtime_dependency.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class OptionalDependency:
6464
OPTUNA = "oracle-ads[optuna]"
6565
SPARK = "oracle-ads[spark]"
6666
HUGGINGFACE = "oracle-ads[huggingface]"
67+
FORECAST = "oracle-ads[forecast]"
68+
PII = "oracle-ads[pii]"
6769
FEATURE_STORE = "oracle-ads[feature-store]"
6870
GRAPHVIZ = "oracle-ads[graphviz]"
6971
MLM_INSIGHTS = "oracle-ads[mlm_insights]"

ads/common/object_storage_details.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import json
@@ -15,7 +15,7 @@
1515
from ads.common import oci_client
1616

1717

18-
class InvalidObjectStoragePath(Exception): # pragma: no cover
18+
class InvalidObjectStoragePath(Exception): # pragma: no cover
1919
"""Invalid Object Storage Path."""
2020

2121
pass
@@ -137,4 +137,4 @@ def is_oci_path(uri: str = None) -> bool:
137137
"""
138138
if not uri:
139139
return False
140-
return uri.startswith("oci://")
140+
return uri.lower().startswith("oci://")

ads/common/serializer.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
"""
8+
This module provides a base class for serializable items, as well as methods for serializing and
9+
deserializing objects to and from JSON and YAML formats. It also includes methods for reading and
10+
writing serialized objects to and from files.
11+
"""
12+
713
import dataclasses
814
import json
915
from abc import ABC, abstractmethod
@@ -271,11 +277,16 @@ def from_yaml(
271277
272278
Parameters
273279
----------
274-
yaml_string (string, optional): YAML string. Defaults to None.
275-
uri (string, optional): URI location of file containing YAML string. Defaults to None.
276-
loader (callable, optional): Custom YAML loader. Defaults to CLoader/SafeLoader.
277-
kwargs (dict): keyword arguments to be passed into fsspec.open(). For OCI object storage, this should be config="path/to/.oci/config".
278-
For other storage connections consider e.g. host, port, username, password, etc.
280+
yaml_string (string, optional)
281+
YAML string. Defaults to None.
282+
uri (string, optional)
283+
URI location of file containing YAML string. Defaults to None.
284+
loader (callable, optional)
285+
Custom YAML loader. Defaults to CLoader/SafeLoader.
286+
kwargs (dict)
287+
keyword arguments to be passed into fsspec.open().
288+
For OCI object storage, this should be config="path/to/.oci/config".
289+
For other storage connections consider e.g. host, port, username, password, etc.
279290
280291
Raises
281292
------
@@ -288,10 +299,10 @@ def from_yaml(
288299
Returns instance of the class
289300
"""
290301
if yaml_string:
291-
return cls.from_dict(yaml.load(yaml_string, Loader=loader))
302+
return cls.from_dict(yaml.load(yaml_string, Loader=loader), **kwargs)
292303
if uri:
293304
yaml_dict = yaml.load(cls._read_from_file(uri=uri, **kwargs), Loader=loader)
294-
return cls.from_dict(yaml_dict)
305+
return cls.from_dict(yaml_dict, **kwargs)
295306
raise ValueError("Must provide either YAML string or URI location")
296307

297308
@classmethod
@@ -345,8 +356,8 @@ class DataClassSerializable(Serializable):
345356
Returns an instance of the class instantiated from the dictionary provided.
346357
"""
347358

348-
@staticmethod
349-
def _validate_dict(obj_dict: Dict) -> bool:
359+
@classmethod
360+
def _validate_dict(cls, obj_dict: Dict) -> bool:
350361
"""validate the dictionary.
351362
352363
Parameters
@@ -379,7 +390,7 @@ def to_dict(self, **kwargs) -> Dict:
379390
obj_dict = dataclasses.asdict(self)
380391
if "side_effect" in kwargs and kwargs["side_effect"]:
381392
obj_dict = DataClassSerializable._normalize_dict(
382-
obj_dict=obj_dict, case=kwargs["side_effect"]
393+
obj_dict=obj_dict, case=kwargs["side_effect"], recursively=True
383394
)
384395
return obj_dict
385396

@@ -388,6 +399,8 @@ def from_dict(
388399
cls,
389400
obj_dict: dict,
390401
side_effect: Optional[SideEffect] = SideEffect.CONVERT_KEYS_TO_LOWER.value,
402+
ignore_unknown: Optional[bool] = False,
403+
**kwargs,
391404
) -> "DataClassSerializable":
392405
"""Returns an instance of the class instantiated by the dictionary provided.
393406
@@ -399,6 +412,8 @@ def from_dict(
399412
side effect to take on the dictionary. The side effect can be either
400413
convert the dictionary keys to "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
401414
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value) cases.
415+
ignore_unknown: (bool, optional). Defaults to `False`.
416+
Whether to ignore unknown fields or not.
402417
403418
Returns
404419
-------
@@ -415,25 +430,36 @@ def from_dict(
415430

416431
allowed_fields = set([f.name for f in dataclasses.fields(cls)])
417432
wrong_fields = set(obj_dict.keys()) - allowed_fields
418-
if wrong_fields:
433+
if wrong_fields and not ignore_unknown:
419434
logger.warning(
420435
f"The class {cls.__name__} doesn't contain attributes: `{list(wrong_fields)}`. "
421436
"These fields will be ignored."
422437
)
423438

424-
obj = cls(**{key: obj_dict[key] for key in allowed_fields})
439+
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})
425440

426441
for key, value in obj_dict.items():
427-
if isinstance(value, dict) and hasattr(
428-
getattr(cls(), key).__class__, "from_dict"
442+
if (
443+
key in allowed_fields
444+
and isinstance(value, dict)
445+
and hasattr(getattr(cls(), key).__class__, "from_dict")
429446
):
430-
attribute = getattr(cls(), key).__class__.from_dict(value)
447+
attribute = getattr(cls(), key).__class__.from_dict(
448+
value,
449+
ignore_unknown=ignore_unknown,
450+
side_effect=side_effect,
451+
**kwargs,
452+
)
431453
setattr(obj, key, attribute)
454+
432455
return obj
433456

434457
@staticmethod
435458
def _normalize_dict(
436-
obj_dict: Dict, case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value
459+
obj_dict: Dict,
460+
recursively: bool = False,
461+
case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value,
462+
**kwargs,
437463
) -> Dict:
438464
"""lower all the keys.
439465
@@ -444,6 +470,8 @@ def _normalize_dict(
444470
case: (optional, str). Defaults to "lower".
445471
the case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
446472
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).
473+
recursively: (bool, optional). Defaults to `False`.
474+
Whether to recursively normalize the dictionary or not.
447475
448476
Returns
449477
-------
@@ -452,12 +480,16 @@ def _normalize_dict(
452480
"""
453481
normalized_obj_dict = {}
454482
for key, value in obj_dict.items():
455-
if isinstance(value, dict):
483+
if recursively and isinstance(value, dict):
456484
value = DataClassSerializable._normalize_dict(
457-
value, case=SideEffect.CONVERT_KEYS_TO_UPPER.value
485+
value, case=case, recursively=recursively, **kwargs
458486
)
459487
normalized_obj_dict = DataClassSerializable._normalize_key(
460-
normalized_obj_dict=normalized_obj_dict, key=key, value=value, case=case
488+
normalized_obj_dict=normalized_obj_dict,
489+
key=key,
490+
value=value,
491+
case=case,
492+
**kwargs,
461493
)
462494
return normalized_obj_dict
463495

@@ -467,7 +499,7 @@ def _normalize_key(
467499
) -> Dict:
468500
"""helper function to normalize the key in the case specified and add it back to the dictionary.
469501
470-
Paramaters
502+
Parameters
471503
----------
472504
normalized_obj_dict: (Dict)
473505
the dictionary to append the key and value to.
@@ -476,17 +508,18 @@ def _normalize_key(
476508
value: (Union[str, Dict])
477509
value to be added.
478510
case: (str)
479-
the case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
511+
The case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
480512
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).
481513
482514
Raises
483515
------
484-
NotImplementedError: if case provided is not either "lower" or "upper".
516+
NotImplementedError
517+
Raised when `case` is not supported.
485518
486519
Returns
487520
-------
488521
Dict
489-
normalized dictionary with the key and value added in the case specified.
522+
Normalized dictionary with the key and value added in the case specified.
490523
"""
491524
if case.lower() == SideEffect.CONVERT_KEYS_TO_LOWER.value:
492525
normalized_obj_dict[key.lower()] = value

0 commit comments

Comments
 (0)