Skip to content

Commit bb2476b

Browse files
Merge branch 'main' into ODSC-65032/md-shapes-api
2 parents 8dc7bba + affeae4 commit bb2476b

File tree

20 files changed

+190
-98
lines changed

20 files changed

+190
-98
lines changed

ads/opctl/operator/common/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -18,23 +17,26 @@
1817
from cerberus import Validator
1918

2019
from ads.opctl import logger, utils
21-
from ads.opctl.operator import __operators__
2220

2321
CONTAINER_NETWORK = "CONTAINER_NETWORK"
2422

2523

2624
class OperatorValidator(Validator):
    """The custom validator class.

    Extends the base ``Validator`` by lower-casing ``spec.model`` before
    validation, so that model names are matched case-insensitively.
    """

    def validate(self, obj_dict, **kwargs):
        """Normalize ``spec.model`` to lower case, then run the base validation.

        Parameters
        ----------
        obj_dict
            The document to validate. When it is a mapping holding a mutable
            ``spec`` mapping with a ``model`` entry, that entry is replaced
            in place by its lower-cased string form.
        **kwargs
            Forwarded unchanged to the base class ``validate``.

        Returns
        -------
        Whatever the base class ``validate`` returns.
        """
        from collections.abc import Mapping, MutableMapping

        # Model should be case insensitive.
        # Only touch the document when it has the expected shape; a missing or
        # malformed `spec` is left for the base validator to report instead of
        # failing here with a KeyError/TypeError.
        spec = obj_dict.get("spec") if isinstance(obj_dict, Mapping) else None
        if isinstance(spec, MutableMapping) and "model" in spec:
            spec["model"] = str(spec["model"]).lower()
        return super().validate(obj_dict, **kwargs)
3032

3133

3234
def create_output_folder(name):
3335
output_folder = name
3436
protocol = fsspec.utils.get_protocol(output_folder)
3537
storage_options = {}
3638
if protocol != "file":
37-
storage_options = auth or default_signer()
39+
storage_options = default_signer()
3840

3941
fs = fsspec.filesystem(protocol, **storage_options)
4042
name_suffix = 1

ads/opctl/operator/lowcode/anomaly/model/base_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ def generate_report(self):
166166
yaml_appendix = rc.Yaml(self.config.to_dict())
167167
summary = rc.Block(
168168
rc.Group(
169-
rc.Text(
170-
f"You selected the **`{self.spec.model}`** model.\n{model_description.text}\n"
171-
),
169+
rc.Text(f"You selected the **`{self.spec.model}`** model.\n"),
170+
model_description,
172171
rc.Text(
173172
"Based on your dataset, you could have also selected "
174173
f"any of the models: `{'`, `'.join(SupportedModels.keys() if self.spec.datetime_column else NonTimeADSupportedModels.keys())}`."

ads/opctl/operator/lowcode/anomaly/model/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ class UnSupportedModelError(Exception):
2626

2727
def __init__(self, operator_config: AnomalyOperatorConfig, model_type: str):
2828
supported_models = (
29-
SupportedModels.values
29+
SupportedModels.values()
3030
if operator_config.spec.datetime_column
31-
else NonTimeADSupportedModels.values
31+
else NonTimeADSupportedModels.values()
3232
)
3333
message = (
3434
f"Model: `{model_type}` is not supported. "

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

6+
from abc import ABC
7+
8+
import pandas as pd
9+
710
from ads.opctl import logger
11+
from ads.opctl.operator.lowcode.common.const import DataColumns
812
from ads.opctl.operator.lowcode.common.errors import (
9-
InvalidParameterError,
1013
DataMismatchError,
14+
InvalidParameterError,
1115
)
12-
from ads.opctl.operator.lowcode.common.const import DataColumns
1316
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
14-
import pandas as pd
15-
from abc import ABC
1617

1718

1819
class Transformations(ABC):
@@ -58,6 +59,7 @@ def run(self, data):
5859
5960
"""
6061
clean_df = self._remove_trailing_whitespace(data)
62+
# clean_df = self._normalize_column_names(clean_df)
6163
if self.name == "historical_data":
6264
self._check_historical_dataset(clean_df)
6365
clean_df = self._set_series_id_column(clean_df)
@@ -95,8 +97,11 @@ def run(self, data):
9597
def _remove_trailing_whitespace(self, df):
9698
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
9799

100+
# def _normalize_column_names(self, df):
101+
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
102+
98103
def _set_series_id_column(self, df):
99-
self._target_category_columns_map = dict()
104+
self._target_category_columns_map = {}
100105
if not self.target_category_columns:
101106
df[DataColumns.Series] = "Series 1"
102107
self.has_artificial_series = True
@@ -125,10 +130,10 @@ def _format_datetime_col(self, df):
125130
df[self.dt_column_name] = pd.to_datetime(
126131
df[self.dt_column_name], format=self.dt_column_format
127132
)
128-
except:
133+
except Exception as ee:
129134
raise InvalidParameterError(
130135
f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
131-
)
136+
) from ee
132137
return df
133138

134139
def _set_multi_index(self, df):
@@ -242,7 +247,6 @@ def _check_historical_dataset(self, df):
242247
"Class": "A",
243248
"Num": 2
244249
},
245-
246250
}
247251
"""
248252

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
import argparse
86
import logging
97
import os
108
import shutil
119
import sys
1210
import tempfile
13-
import time
14-
from string import Template
15-
from typing import Any, Dict, List, Tuple
16-
import pandas as pd
17-
from ads.opctl import logger
18-
import oracledb
11+
from typing import List, Union
1912

2013
import fsspec
21-
import yaml
22-
from typing import Union
14+
import oracledb
15+
import pandas as pd
2316

17+
from ads.common.object_storage_details import ObjectStorageDetails
2418
from ads.opctl import logger
19+
from ads.opctl.operator.common.operator_config import OutputDirectory
2520
from ads.opctl.operator.lowcode.common.errors import (
26-
InputDataError,
2721
InvalidParameterError,
28-
PermissionsError,
29-
DataMismatchError,
3022
)
31-
from ads.opctl.operator.common.operator_config import OutputDirectory
32-
from ads.common.object_storage_details import ObjectStorageDetails
3323
from ads.secrets import ADBSecretKeeper
3424

3525

3626
def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
37-
if fsspec.utils.get_protocol(filename) == "file":
38-
return pd_fn(filename, **kwargs)
39-
elif fsspec.utils.get_protocol(filename) in ["http", "https"]:
27+
if fsspec.utils.get_protocol(filename) == "file" or fsspec.utils.get_protocol(
28+
filename
29+
) in ["http", "https"]:
4030
return pd_fn(filename, **kwargs)
4131

4232
storage_options = storage_options or (
@@ -48,7 +38,7 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
4838

4939
def load_data(data_spec, storage_options=None, **kwargs):
5040
if data_spec is None:
51-
raise InvalidParameterError(f"No details provided for this data source.")
41+
raise InvalidParameterError("No details provided for this data source.")
5242
filename = data_spec.url
5343
format = data_spec.format
5444
columns = data_spec.columns
@@ -67,7 +57,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
6757
if not format:
6858
_, format = os.path.splitext(filename)
6959
format = format[1:]
70-
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
60+
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf", "parquet"]:
7161
read_fn = getattr(pd, f"read_{format}")
7262
data = call_pandas_fsspec(
7363
read_fn, filename, storage_options=storage_options
@@ -84,19 +74,31 @@ def load_data(data_spec, storage_options=None, **kwargs):
8474
with tempfile.TemporaryDirectory() as temp_dir:
8575
if vault_secret_id is not None:
8676
try:
87-
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
88-
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
89-
shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
90-
connect_args['wallet_location'] = temp_dir
91-
if 'user_name' in adwsecret and 'user' not in connect_args:
92-
connect_args['user'] = adwsecret['user_name']
93-
if 'password' in adwsecret and 'password' not in connect_args:
94-
connect_args['password'] = adwsecret['password']
95-
if 'service_name' in adwsecret and 'service_name' not in connect_args:
96-
connect_args['service_name'] = adwsecret['service_name']
77+
with ADBSecretKeeper.load_secret(
78+
vault_secret_id, wallet_dir=temp_dir
79+
) as adwsecret:
80+
if (
81+
"wallet_location" in adwsecret
82+
and "wallet_location" not in connect_args
83+
):
84+
shutil.unpack_archive(
85+
adwsecret["wallet_location"], temp_dir
86+
)
87+
connect_args["wallet_location"] = temp_dir
88+
if "user_name" in adwsecret and "user" not in connect_args:
89+
connect_args["user"] = adwsecret["user_name"]
90+
if "password" in adwsecret and "password" not in connect_args:
91+
connect_args["password"] = adwsecret["password"]
92+
if (
93+
"service_name" in adwsecret
94+
and "service_name" not in connect_args
95+
):
96+
connect_args["service_name"] = adwsecret["service_name"]
9797

9898
except Exception as e:
99-
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")
99+
raise Exception(
100+
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
101+
)
100102

101103
con = oracledb.connect(**connect_args)
102104
if table_name is not None:
@@ -105,11 +107,11 @@ def load_data(data_spec, storage_options=None, **kwargs):
105107
data = pd.read_sql(sql, con)
106108
else:
107109
raise InvalidParameterError(
108-
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
110+
"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
109111
)
110112
else:
111113
raise InvalidParameterError(
112-
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
114+
"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
113115
)
114116
if columns:
115117
# keep only these columns, done after load because only CSV supports stream filtering
@@ -232,7 +234,7 @@ def human_time_friendly(seconds):
232234
accumulator.append(
233235
"{} {}{}".format(int(amount), unit, "" if amount == 1 else "s")
234236
)
235-
accumulator.append("{} secs".format(round(seconds, 2)))
237+
accumulator.append(f"{round(seconds, 2)} secs")
236238
return ", ".join(accumulator)
237239

238240

@@ -248,9 +250,7 @@ def find_output_dirname(output_dir: OutputDirectory):
248250
unique_output_dir = f"{output_dir}_{counter}"
249251
counter += 1
250252
logger.warn(
251-
"Since the output directory was not specified, the output will be saved to {} directory.".format(
252-
unique_output_dir
253-
)
253+
f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
254254
)
255255
return unique_output_dir
256256

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import logging
5+
import os
56
import traceback
67

78
import numpy as np
@@ -80,10 +81,17 @@ def _build_model(self) -> pd.DataFrame:
8081

8182
from automlx import Pipeline, init
8283

84+
cpu_count = os.cpu_count()
8385
try:
86+
# Pick the execution engine from the detected CPU count.
# os.cpu_count() may return None when the count cannot be determined;
# treat that like a small machine and stay on the local engine.
if cpu_count is None or cpu_count < 4:
    engine = "local"
    engine_opts = None
else:
    engine = "ray"
    # Must be a plain dict: a trailing comma here would wrap it in a
    # one-element tuple, which is not what was passed to init() before.
    engine_opts = {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
8492
init(
85-
engine="ray",
86-
engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},
93+
engine=engine,
94+
engine_opts=engine_opts,
8795
loglevel=logging.CRITICAL,
8896
)
8997
except Exception as e:

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,9 @@ def generate_report(self):
148148
header_section = rc.Block(
149149
rc.Heading("Forecast Report", level=1),
150150
rc.Text(
151-
f"You selected the {self.spec.model} model.\n{model_description}\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
151+
f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
152152
),
153+
model_description,
153154
rc.Group(
154155
rc.Metric(
155156
heading="Analysis was completed in ",

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .autots import AutoTSOperatorModel
1212
from .base_model import ForecastOperatorBaseModel
1313
from .forecast_datasets import ForecastDatasets
14+
from .ml_forecast import MLForecastOperatorModel
1415
from .neuralprophet import NeuralProphetOperatorModel
1516
from .prophet import ProphetOperatorModel
1617

@@ -19,7 +20,7 @@ class UnSupportedModelError(Exception):
1920
def __init__(self, model_type: str):
2021
super().__init__(
2122
f"Model: `{model_type}` "
22-
f"is not supported. Supported models: {SupportedModels.values}"
23+
f"is not supported. Supported models: {SupportedModels.values()}"
2324
)
2425

2526

@@ -32,7 +33,7 @@ class ForecastOperatorModelFactory:
3233
SupportedModels.Prophet: ProphetOperatorModel,
3334
SupportedModels.Arima: ArimaOperatorModel,
3435
SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
35-
# SupportedModels.LGBForecast: MLForecastOperatorModel,
36+
SupportedModels.LGBForecast: MLForecastOperatorModel,
3637
SupportedModels.AutoMLX: AutoMLXOperatorModel,
3738
SupportedModels.AutoTS: AutoTSOperatorModel,
3839
}

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def _build_model(self) -> pd.DataFrame:
142142
dt_column=self.spec.datetime_column.name,
143143
)
144144

145+
# if os.environ["OCI__IS_SPARK"]:
146+
# pass
147+
# else:
145148
Parallel(n_jobs=-1, require="sharedmem")(
146149
delayed(ProphetOperatorModel._train_model)(
147150
self, i, series_id, df, model_kwargs.copy()
@@ -354,7 +357,7 @@ def _generate_report(self):
354357
logger.warn(f"Failed to generate Explanations with error: {e}.")
355358
logger.debug(f"Full Traceback: {traceback.format_exc()}")
356359

357-
model_description = (
360+
model_description = rc.Text(
358361
"Prophet is a procedure for forecasting time series data based on an additive "
359362
"model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
360363
"plus holiday effects. It works best with time series that have strong seasonal "

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ spec:
311311
missing_value_imputation:
312312
type: boolean
313313
required: false
314-
default: false
314+
default: true
315315
outlier_treatment:
316316
type: boolean
317317
required: false

0 commit comments

Comments
 (0)