Skip to content

Commit a85beec

Browse files
committed
adding additonal test
1 parent 4d2ac61 commit a85beec

File tree

1 file changed

+59
-8
lines changed

1 file changed

+59
-8
lines changed

tests/operators/test_datasets.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from copy import deepcopy
1414
from pathlib import Path
1515
import random
16+
import pathlib
1617

1718

1819
DATASETS_LIST = [
@@ -43,6 +44,15 @@
4344
"WoolyDataset",
4445
]
4546

47+
MODELS = [
48+
"arima",
49+
"automlx",
50+
"prophet",
51+
"neuralprophet",
52+
"autots",
53+
"auto",
54+
]
55+
4656
TEMPLATE_YAML = {
4757
"kind": "operator",
4858
"type": "forecast",
@@ -73,14 +83,9 @@
7383
parameters_short = []
7484

7585
for dataset_i in DATASETS_LIST[2:3]: # + [DATASETS_LIST[-2]]
76-
for model in [
77-
"arima",
78-
"automlx",
79-
"prophet",
80-
"neuralprophet",
81-
"autots",
82-
"auto",
83-
]: # ["arima", "automlx", "prophet", "neuralprophet", "autots", "auto"]
86+
for (
87+
model
88+
) in MODELS: # ["arima", "automlx", "prophet", "neuralprophet", "autots", "auto"]
8489
parameters_short.append((model, dataset_i))
8590

8691

@@ -175,6 +180,52 @@ def test_load_datasets(model, dataset_name):
175180
return test_metrics.iloc[0][f"{columns[0]}_A"]
176181

177182

183+
@pytest.mark.parametrize("model", MODELS)
184+
def test_rossman(model):
185+
curr_dir = pathlib.Path(__file__).parent.resolve()
186+
data_folder = f"{curr_dir}/data/"
187+
historical_data_path = f"{curr_dir}/data/rs_10_prim.csv"
188+
additional_data_path = f"{curr_dir}/data/rs_10_add.csv"
189+
test_data_path = f"{curr_dir}/data/rs_10_test.csv"
190+
191+
with tempfile.TemporaryDirectory() as tmpdirname:
192+
output_data_path = f"{tmpdirname}/results"
193+
yaml_i = deepcopy(TEMPLATE_YAML)
194+
generate_train_metrics = True
195+
196+
yaml_i["spec"]["additional_data"] = {"url": additional_data_path}
197+
yaml_i["spec"]["historical_data"]["url"] = historical_data_path
198+
yaml_i["spec"]["test_data"] = {"url": test_data_path}
199+
yaml_i["spec"]["output_directory"]["url"] = output_data_path
200+
yaml_i["spec"]["model"] = model
201+
yaml_i["spec"]["target_column"] = "Sales"
202+
yaml_i["spec"]["datetime_column"]["name"] = "Date"
203+
yaml_i["spec"]["target_category_columns"] = ["Store"]
204+
yaml_i["spec"]["horizon"] = PERIODS
205+
206+
if generate_train_metrics:
207+
yaml_i["spec"]["generate_metrics"] = generate_train_metrics
208+
if model == "autots":
209+
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
210+
if model == "automlx":
211+
yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
212+
213+
forecast_yaml_filename = f"{tmpdirname}/forecast.yaml"
214+
with open(f"{tmpdirname}/forecast.yaml", "w") as f:
215+
f.write(yaml.dump(yaml_i))
216+
sleep(0.1)
217+
subprocess.run(
218+
f"ads operator run -f {forecast_yaml_filename} --debug", shell=True
219+
)
220+
sleep(0.1)
221+
subprocess.run(f"ls -a {output_data_path}", shell=True)
222+
223+
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
224+
print(test_metrics)
225+
train_metrics = pd.read_csv(f"{tmpdirname}/results/metrics.csv")
226+
print(train_metrics)
227+
228+
178229
if __name__ == "__main__":
179230
failed_runs = []
180231
results = dict()

0 commit comments

Comments
 (0)