Skip to content

Commit f56b5ff

Browse files
committed
adding more tests
1 parent f939cbc commit f56b5ff

File tree

10 files changed

+66
-48
lines changed

10 files changed

+66
-48
lines changed

ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ..operator_config import AnomalyOperatorSpec
88
from ads.opctl.operator.lowcode.common.utils import (
99
default_signer,
10-
load_data,
1110
merge_category_columns,
1211
)
1312
from ads.opctl.operator.lowcode.common.data import AbstractData

ads/opctl/operator/lowcode/anomaly/model/base_model.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from ..const import SupportedModels
2424
from ads.opctl.operator.lowcode.common.utils import (
2525
human_time_friendly,
26-
load_data,
2726
enable_print,
2827
disable_print,
2928
write_data,
@@ -325,17 +324,17 @@ def _fallback_build_model(self):
325324
for target, df in self.datasets.full_data_dict.items():
326325
est = linear_model.SGDOneClassSVM(random_state=42)
327326
est.fit(df[target].values.reshape(-1, 1))
328-
y_pred = np.vectorize(self.outlier_map.get)(est.predict(df[target].values.reshape(-1, 1)))
327+
y_pred = np.vectorize(self.outlier_map.get)(
328+
est.predict(df[target].values.reshape(-1, 1))
329+
)
329330
scores = est.score_samples(df[target].values.reshape(-1, 1))
330331

331-
anomaly = pd.DataFrame({
332-
date_column: df[date_column],
333-
OutputColumns.ANOMALY_COL: y_pred
334-
}).reset_index(drop=True)
335-
score = pd.DataFrame({
336-
date_column: df[date_column],
337-
OutputColumns.SCORE_COL: scores
338-
}).reset_index(drop=True)
332+
anomaly = pd.DataFrame(
333+
{date_column: df[date_column], OutputColumns.ANOMALY_COL: y_pred}
334+
).reset_index(drop=True)
335+
score = pd.DataFrame(
336+
{date_column: df[date_column], OutputColumns.SCORE_COL: scores}
337+
).reset_index(drop=True)
339338
anomaly_output.add_output(target, anomaly, score)
340339

341340
return anomaly_output

ads/opctl/operator/lowcode/common/data.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,7 @@ def get_data_for_series(self, series_id):
5252
def _load_data(self, data_spec, **kwargs):
5353
loading_start_time = time.time()
5454
try:
55-
raw_data = load_data(
56-
filename=data_spec.url,
57-
format=data_spec.format,
58-
columns=data_spec.columns,
59-
connect_args=data_spec.connect_args,
60-
sql=data_spec.sql,
61-
table_name=data_spec.table_name,
62-
limit=data_spec.limit,
63-
)
55+
raw_data = load_data(data_spec)
6456
except InvalidParameterError as e:
6557
e.args = e.args + (f"Invalid Parameter: {self.name}",)
6658
raise e

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
DataMismatchError,
2828
)
2929
from ads.opctl.operator.common.operator_config import OutputDirectory
30+
from ads.common.object_storage_details import ObjectStorageDetails
3031

3132

3233
def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
@@ -42,17 +43,21 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
4243
return pd_fn(filename, storage_options=storage_options, **kwargs)
4344

4445

45-
def load_data(
46-
filename=None,
47-
format=None,
48-
storage_options=None,
49-
columns=None,
50-
connect_args=None,
51-
sql=None,
52-
table_name=None,
53-
limit=None,
54-
**kwargs,
55-
):
46+
def load_data(data_spec, storage_options=None, **kwargs):
47+
if data_spec is None:
48+
raise InvalidParameterError(f"No details provided for this data source.")
49+
filename = data_spec.url
50+
format = data_spec.format
51+
columns = data_spec.columns
52+
connect_args = data_spec.connect_args
53+
sql = data_spec.sql
54+
table_name = data_spec.table_name
55+
limit = data_spec.limit
56+
57+
storage_options = storage_options or (
58+
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
59+
)
60+
5661
if filename is not None:
5762
if not format:
5863
_, format = os.path.splitext(filename)

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ def _generate_report(self):
171171
self.formatted_global_explanation = (
172172
global_explanation_df / global_explanation_df.sum(axis=0) * 100
173173
)
174+
self.formatted_global_explanation = (
175+
self.formatted_global_explanation.rename(
176+
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
177+
axis=1,
178+
)
179+
)
174180

175181
# Create a markdown section for the global explainability
176182
global_explanation_section = dp.Blocks(

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ def _generate_report(self):
241241
self.formatted_global_explanation = (
242242
global_explanation_df / global_explanation_df.sum(axis=0) * 100
243243
)
244+
self.formatted_global_explanation = (
245+
self.formatted_global_explanation.rename(
246+
{self.spec.datetime_column.name: ForecastOutputColumns.DATE}, axis=1
247+
)
248+
)
244249

245250
# Create a markdown section for the global explainability
246251
global_explanation_section = dp.Blocks(

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
SupportedModels,
4646
SpeedAccuracyMode,
4747
)
48+
from ..const import ForecastOutputColumns
4849
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
4950
from ads.common.decorator.runtime_dependency import runtime_dependency
5051
from .forecast_datasets import ForecastDatasets, ForecastOutput
@@ -710,6 +711,9 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
710711
local_kernel_explnr_df = pd.DataFrame(
711712
local_kernel_explnr_vals, columns=data.columns
712713
)
714+
local_kernel_explnr_df = local_kernel_explnr_df.rename(
715+
{datetime_col_name: ForecastOutputColumns.DATE}, axis=0
716+
)
713717
self.local_explanation[series_id] = local_kernel_explnr_df
714718

715719
def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ..const import ForecastOutputColumns, PROPHET_INTERNAL_DATE_COL
1414
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.opctl.operator.lowcode.common.utils import (
16-
load_data,
1716
get_frequency_in_seconds,
1817
get_frequency_of_datetime,
1918
)

tests/operators/forecast/test_datasets.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
DATASETS_LIST = [
22-
"AirPassengersDataset",
23-
"AusBeerDataset",
22+
# "AirPassengersDataset",
23+
# "AusBeerDataset",
2424
"AustralianTourismDataset",
2525
# # "ETTh1Dataset",
2626
# # "ETTh2Dataset",
@@ -34,7 +34,7 @@
3434
# # "ILINetDataset",
3535
# # "IceCreamHeaterDataset",
3636
# # "MonthlyMilkDataset",
37-
"MonthlyMilkIncompleteDataset",
37+
# "MonthlyMilkIncompleteDataset",
3838
# # "SunspotsDataset",
3939
# # "TaylorDataset",
4040
# # "TemperatureDataset",
@@ -43,7 +43,7 @@
4343
# "UberTLCDataset",
4444
# "WeatherDataset",
4545
# "WineDataset",
46-
"WoolyDataset",
46+
# "WoolyDataset",
4747
]
4848

4949
MODELS = [
@@ -89,6 +89,18 @@
8989
parameters_short.append((model, dataset_i))
9090

9191

92+
def verify_explanations(global_fn, local_fn, yaml_i, additional_cols):
93+
glb_expl = pd.read_csv(global_fn, index_col=0)
94+
loc_expl = pd.read_csv(local_fn)
95+
assert loc_expl.shape[0] == PERIODS
96+
for x in ["Date", "Series"]:
97+
assert x in set(loc_expl.columns)
98+
for x in additional_cols:
99+
assert x in set(loc_expl.columns)
100+
assert x in set(glb_expl.index)
101+
assert "Series 1" in set(glb_expl.columns)
102+
103+
92104
@pytest.mark.parametrize("model, dataset_name", parameters_short)
93105
def test_load_datasets(model, dataset_name):
94106
if model == "automlx" and dataset_name == "WeatherDataset":
@@ -97,6 +109,7 @@ def test_load_datasets(model, dataset_name):
97109
datetime_col = dataset_i.time_index.name
98110

99111
columns = dataset_i.components
112+
additional_cols = []
100113
target = dataset_i[columns[0]][:-PERIODS]
101114
test = dataset_i[columns[0]][-PERIODS:]
102115

@@ -145,7 +158,7 @@ def test_load_datasets(model, dataset_name):
145158
yaml_i["spec"]["target_column"] = columns[0]
146159
yaml_i["spec"]["datetime_column"]["name"] = datetime_col
147160
yaml_i["spec"]["horizon"] = PERIODS
148-
if yaml_i["spec"].get("additional_data") is not None and model != "automlx":
161+
if yaml_i["spec"].get("additional_data") is not None and model != "autots":
149162
yaml_i["spec"]["generate_explanations"] = True
150163
if generate_train_metrics:
151164
yaml_i["spec"]["generate_metrics"] = generate_train_metrics
@@ -164,11 +177,13 @@ def test_load_datasets(model, dataset_name):
164177
# sleep(0.1)
165178
run(yaml_i, backend="operator.local", debug=False)
166179
subprocess.run(f"ls -a {output_data_path}", shell=True)
167-
if yaml_i["spec"]["generate_explanations"] and model != "autots":
168-
glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv")
169-
print(glb_expl)
170-
loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
171-
print(loc_expl)
180+
if yaml_i["spec"]["generate_explanations"]:
181+
verify_explanations(
182+
global_fn=f"{tmpdirname}/results/global_explanation.csv",
183+
local_fn=f"{tmpdirname}/results/local_explanation.csv",
184+
yaml_i=yaml_i,
185+
additional_cols=additional_cols,
186+
)
172187

173188
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
174189
print(test_metrics)

tests/operators/forecast/test_errors.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,6 @@ def test_historical_data(operator_setup, model):
289289
tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path
290290
)
291291

292-
yaml_i["spec"]["historical_data"] = None
293-
with pytest.raises(InvalidParameterError):
294-
run_yaml(
295-
tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path
296-
)
297-
298292
yaml_i["spec"].pop("historical_data")
299293
yaml_i["spec"]["TEST"] = historical_data
300294
with pytest.raises(InvalidParameterError):

0 commit comments

Comments (0)