Skip to content

Commit 3838d2a

Browse files
committed
add unit test
1 parent 8149235 commit 3838d2a

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from abc import ABC
77

8+
import numpy as np
89
import pandas as pd
910

1011
from ads.opctl import logger
@@ -215,12 +216,17 @@ def _outlier_treatment(self, df):
215216
.transform(lambda x: (x - x.mean()) / x.std())
216217
)
217218
outliers_mask = df["z_score"].abs() > 3
219+
220+
if df[self.target_column_name].dtype == np.int:
221+
df[self.target_column_name].astype(np.float)
222+
218223
df.loc[outliers_mask, self.target_column_name] = (
219224
df[self.target_column_name]
220225
.groupby(DataColumns.Series)
221-
.transform(lambda x: x.mean())
226+
.transform(lambda x: np.median(x))
222227
)
223-
return df.drop("z_score", axis=1)
228+
df_ret = df.drop("z_score", axis=1)
229+
return df_ret
224230

225231
def _check_historical_dataset(self, df):
226232
expected_names = [self.target_column_name, self.dt_column_name] + (

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging
@@ -52,7 +52,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
5252
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
5353
)
5454
if vault_secret_id is not None and connect_args is None:
55-
connect_args = dict()
55+
connect_args = {}
5656

5757
if data is not None:
5858
if format == "spark":
@@ -102,7 +102,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
102102
except Exception as e:
103103
raise Exception(
104104
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
105-
)
105+
) from e
106106

107107
con = oracledb.connect(**connect_args)
108108
if table_name is not None:
@@ -126,6 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
126126

127127

128128
def write_data(data, filename, format, storage_options, index=False, **kwargs):
129+
disable_print()
129130
if not format:
130131
_, format = os.path.splitext(filename)
131132
format = format[1:]
@@ -134,7 +135,8 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
134135
return call_pandas_fsspec(
135136
write_fn, filename, index=index, storage_options=storage_options, **kwargs
136137
)
137-
raise OperatorYamlContentError(
138+
enable_print()
139+
raise InvalidParameterError(
138140
f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
139141
)
140142

tests/operators/forecast/test_datasets.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
import os
66
import yaml
@@ -159,6 +159,39 @@ def test_load_datasets(model, data_details):
159159
print(train_metrics)
160160

161161

162+
@pytest.mark.parametrize("model", MODELS[:-1])
163+
def test_pandas_df_historical(model):
164+
df = pd.read_csv(f"{DATASET_PREFIX}dataset1.csv")
165+
166+
yaml_i = deepcopy(TEMPLATE_YAML)
167+
yaml_i["spec"]["model"] = model
168+
yaml_i["spec"]["historical_data"].pop("url")
169+
yaml_i["spec"]["historical_data"]["data"] = df
170+
yaml_i["spec"]["target_column"] = "Y"
171+
yaml_i["spec"]["datetime_column"]["name"] = DATETIME_COL
172+
yaml_i["spec"]["horizon"] = 5
173+
run(yaml_i, backend="operator.local", debug=False)
174+
subprocess.run(f"ls -a {output_data_path}", shell=True)
175+
176+
177+
@pytest.mark.parametrize("model", MODELS[:-1])
178+
def test_pandas_historical_test(model):
179+
df = pd.read_csv(f"{DATASET_PREFIX}dataset4.csv")
180+
df_train = df[:-1]
181+
df_test = df[-1:]
182+
183+
yaml_i = deepcopy(TEMPLATE_YAML)
184+
yaml_i["spec"]["model"] = model
185+
yaml_i["spec"]["historical_data"].pop("url")
186+
yaml_i["spec"]["historical_data"]["data"] = df_train
187+
yaml_i["spec"]["test_data"]["data"] = df_test
188+
yaml_i["spec"]["target_column"] = "Y"
189+
yaml_i["spec"]["datetime_column"]["name"] = DATETIME_COL
190+
yaml_i["spec"]["horizon"] = 5
191+
run(yaml_i, backend="operator.local", debug=False)
192+
subprocess.run(f"ls -a {output_data_path}", shell=True)
193+
194+
162195
def run_operator(
163196
historical_data_path,
164197
additional_data_path,

0 commit comments

Comments
 (0)