Skip to content

Commit f17b49b

Browse files
authored
adding more tests (#577)
2 parents d66b885 + 52a0929 commit f17b49b

File tree

10 files changed

+63
-47
lines changed

10 files changed

+63
-47
lines changed

ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from ..operator_config import AnomalyOperatorSpec
88
from ads.opctl.operator.lowcode.common.utils import (
99
default_signer,
10-
load_data,
1110
merge_category_columns,
1211
)
1312
from ads.opctl.operator.lowcode.common.data import AbstractData

ads/opctl/operator/lowcode/anomaly/model/base_model.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
from ..const import SupportedModels
2424
from ads.opctl.operator.lowcode.common.utils import (
2525
human_time_friendly,
26-
load_data,
2726
enable_print,
2827
disable_print,
2928
write_data,
@@ -325,17 +324,17 @@ def _fallback_build_model(self):
325324
for target, df in self.datasets.full_data_dict.items():
326325
est = linear_model.SGDOneClassSVM(random_state=42)
327326
est.fit(df[target].values.reshape(-1, 1))
328-
y_pred = np.vectorize(self.outlier_map.get)(est.predict(df[target].values.reshape(-1, 1)))
327+
y_pred = np.vectorize(self.outlier_map.get)(
328+
est.predict(df[target].values.reshape(-1, 1))
329+
)
329330
scores = est.score_samples(df[target].values.reshape(-1, 1))
330331

331-
anomaly = pd.DataFrame({
332-
date_column: df[date_column],
333-
OutputColumns.ANOMALY_COL: y_pred
334-
}).reset_index(drop=True)
335-
score = pd.DataFrame({
336-
date_column: df[date_column],
337-
OutputColumns.SCORE_COL: scores
338-
}).reset_index(drop=True)
332+
anomaly = pd.DataFrame(
333+
{date_column: df[date_column], OutputColumns.ANOMALY_COL: y_pred}
334+
).reset_index(drop=True)
335+
score = pd.DataFrame(
336+
{date_column: df[date_column], OutputColumns.SCORE_COL: scores}
337+
).reset_index(drop=True)
339338
anomaly_output.add_output(target, anomaly, score)
340339

341340
return anomaly_output

ads/opctl/operator/lowcode/common/data.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,15 +52,7 @@ def get_data_for_series(self, series_id):
5252
def _load_data(self, data_spec, **kwargs):
5353
loading_start_time = time.time()
5454
try:
55-
raw_data = load_data(
56-
filename=data_spec.url,
57-
format=data_spec.format,
58-
columns=data_spec.columns,
59-
connect_args=data_spec.connect_args,
60-
sql=data_spec.sql,
61-
table_name=data_spec.table_name,
62-
limit=data_spec.limit,
63-
)
55+
raw_data = load_data(data_spec)
6456
except InvalidParameterError as e:
6557
e.args = e.args + (f"Invalid Parameter: {self.name}",)
6658
raise e

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
DataMismatchError,
2828
)
2929
from ads.opctl.operator.common.operator_config import OutputDirectory
30+
from ads.common.object_storage_details import ObjectStorageDetails
3031

3132

3233
def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
@@ -42,17 +43,21 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
4243
return pd_fn(filename, storage_options=storage_options, **kwargs)
4344

4445

45-
def load_data(
46-
filename=None,
47-
format=None,
48-
storage_options=None,
49-
columns=None,
50-
connect_args=None,
51-
sql=None,
52-
table_name=None,
53-
limit=None,
54-
**kwargs,
55-
):
46+
def load_data(data_spec, storage_options=None, **kwargs):
47+
if data_spec is None:
48+
raise InvalidParameterError(f"No details provided for this data source.")
49+
filename = data_spec.url
50+
format = data_spec.format
51+
columns = data_spec.columns
52+
connect_args = data_spec.connect_args
53+
sql = data_spec.sql
54+
table_name = data_spec.table_name
55+
limit = data_spec.limit
56+
57+
storage_options = storage_options or (
58+
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
59+
)
60+
5661
if filename is not None:
5762
if not format:
5863
_, format = os.path.splitext(filename)

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,12 @@ def _generate_report(self):
171171
self.formatted_global_explanation = (
172172
global_explanation_df / global_explanation_df.sum(axis=0) * 100
173173
)
174+
self.formatted_global_explanation = (
175+
self.formatted_global_explanation.rename(
176+
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
177+
axis=1,
178+
)
179+
)
174180

175181
# Create a markdown section for the global explainability
176182
global_explanation_section = dp.Blocks(

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,11 @@ def _generate_report(self):
241241
self.formatted_global_explanation = (
242242
global_explanation_df / global_explanation_df.sum(axis=0) * 100
243243
)
244+
self.formatted_global_explanation = (
245+
self.formatted_global_explanation.rename(
246+
{self.spec.datetime_column.name: ForecastOutputColumns.DATE}, axis=1
247+
)
248+
)
244249

245250
# Create a markdown section for the global explainability
246251
global_explanation_section = dp.Blocks(

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
from ..const import ForecastOutputColumns, PROPHET_INTERNAL_DATE_COL
1414
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.opctl.operator.lowcode.common.utils import (
16-
load_data,
1716
get_frequency_in_seconds,
1817
get_frequency_of_datetime,
1918
)

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,15 +440,17 @@ def explain_model(self):
440440
for s_id, expl_df in self.explanations_info.items():
441441
expl_df = expl_df.rename(rename_cols, axis=1)
442442
# Local Expl
443-
self.local_explanation[s_id] = self.get_horizon(expl_df)
443+
self.local_explanation[s_id] = self.get_horizon(expl_df).drop(
444+
["future_regressors_additive"], axis=1
445+
)
444446
self.local_explanation[s_id]["Series"] = s_id
445-
447+
self.local_explanation[s_id].index.rename(self.dt_column_name, inplace=True)
446448
# Global Expl
447449
g_expl = self.drop_horizon(expl_df).mean()
448450
g_expl.name = s_id
449451
global_expl.append(g_expl)
450452
self.global_explanation = pd.concat(global_expl, axis=1)
451-
self.formatted_global_explanation = self.global_explanation.drop(
453+
self.global_explanation = self.global_explanation.drop(
452454
index=["future_regressors_additive"], axis=0
453455
)
454456
self.formatted_global_explanation = (

tests/operators/forecast/test_datasets.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,18 @@
8989
parameters_short.append((model, dataset_i))
9090

9191

92+
def verify_explanations(global_fn, local_fn, yaml_i, additional_cols):
93+
glb_expl = pd.read_csv(global_fn, index_col=0)
94+
loc_expl = pd.read_csv(local_fn)
95+
assert loc_expl.shape[0] == PERIODS
96+
for x in [yaml_i["spec"]["datetime_column"]["name"], "Series"]:
97+
assert x in set(loc_expl.columns)
98+
for x in additional_cols:
99+
assert x in set(loc_expl.columns)
100+
assert x in set(glb_expl.index)
101+
assert "Series 1" in set(glb_expl.columns)
102+
103+
92104
@pytest.mark.parametrize("model, dataset_name", parameters_short)
93105
def test_load_datasets(model, dataset_name):
94106
if model == "automlx" and dataset_name == "WeatherDataset":
@@ -97,6 +109,7 @@ def test_load_datasets(model, dataset_name):
97109
datetime_col = dataset_i.time_index.name
98110

99111
columns = dataset_i.components
112+
additional_cols = []
100113
target = dataset_i[columns[0]][:-PERIODS]
101114
test = dataset_i[columns[0]][-PERIODS:]
102115

@@ -145,7 +158,7 @@ def test_load_datasets(model, dataset_name):
145158
yaml_i["spec"]["target_column"] = columns[0]
146159
yaml_i["spec"]["datetime_column"]["name"] = datetime_col
147160
yaml_i["spec"]["horizon"] = PERIODS
148-
if yaml_i["spec"].get("additional_data") is not None and model != "automlx":
161+
if yaml_i["spec"].get("additional_data") is not None and model != "autots":
149162
yaml_i["spec"]["generate_explanations"] = True
150163
if generate_train_metrics:
151164
yaml_i["spec"]["generate_metrics"] = generate_train_metrics
@@ -164,11 +177,13 @@ def test_load_datasets(model, dataset_name):
164177
# sleep(0.1)
165178
run(yaml_i, backend="operator.local", debug=False)
166179
subprocess.run(f"ls -a {output_data_path}", shell=True)
167-
if yaml_i["spec"]["generate_explanations"] and model != "autots":
168-
glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv")
169-
print(glb_expl)
170-
loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
171-
print(loc_expl)
180+
if yaml_i["spec"]["generate_explanations"]:
181+
verify_explanations(
182+
global_fn=f"{tmpdirname}/results/global_explanation.csv",
183+
local_fn=f"{tmpdirname}/results/local_explanation.csv",
184+
yaml_i=yaml_i,
185+
additional_cols=additional_cols,
186+
)
172187

173188
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
174189
print(test_metrics)

tests/operators/forecast/test_errors.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -289,12 +289,6 @@ def test_historical_data(operator_setup, model):
289289
tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path
290290
)
291291

292-
yaml_i["spec"]["historical_data"] = None
293-
with pytest.raises(InvalidParameterError):
294-
run_yaml(
295-
tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path
296-
)
297-
298292
yaml_i["spec"].pop("historical_data")
299293
yaml_i["spec"]["TEST"] = historical_data
300294
with pytest.raises(InvalidParameterError):

0 commit comments

Comments (0)