Skip to content

Commit 72f2e2d

Browse files
committed
use all models in list
1 parent a2bb627 commit 72f2e2d

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

tests/operators/forecast/test_datasets.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131

3232
MODELS = [
3333
"arima",
34-
# "automlx",
34+
"automlx",
3535
"prophet",
3636
"neuralprophet",
37-
# "autots",
37+
"autots",
3838
# "lgbforecast",
3939
"auto-select",
4040
]
@@ -156,7 +156,7 @@ def test_load_datasets(model, data_details):
156156
verify_explanations(
157157
tmpdirname=tmpdirname,
158158
additional_cols=additional_cols,
159-
target_category_columns=yaml_i["spec"]['target_category_columns']
159+
target_category_columns=yaml_i["spec"]["target_category_columns"],
160160
)
161161
if include_test_data:
162162
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
@@ -165,7 +165,7 @@ def test_load_datasets(model, data_details):
165165
print(train_metrics)
166166

167167

168-
@pytest.mark.parametrize("model", MODELS[:1])
168+
@pytest.mark.parametrize("model", MODELS[:-1])
169169
def test_pandas_to_historical(model):
170170
df = pd.read_csv(f"{DATASET_PREFIX}dataset1.csv")
171171

@@ -184,7 +184,7 @@ def test_pandas_to_historical(model):
184184
check_output_for_errors(output_data_path)
185185

186186

187-
@pytest.mark.parametrize("model", ["neuralprophet"])
187+
@pytest.mark.parametrize("model", MODELS[:-1])
188188
def test_pandas_to_historical_test(model):
189189
df = pd.read_csv(f"{DATASET_PREFIX}dataset4.csv")
190190
df_train = df[:-PERIODS]
@@ -207,26 +207,33 @@ def test_pandas_to_historical_test(model):
207207
test_metrics = pd.read_csv(f"{output_data_path}/metrics.csv")
208208
print(test_metrics)
209209

210+
210211
def check_output_for_errors(output_data_path):
211212
# try:
212213
# List files in the directory
213-
result = subprocess.run(f"ls -a {output_data_path}", shell=True, check=True, text=True, capture_output=True)
214+
result = subprocess.run(
215+
f"ls -a {output_data_path}",
216+
shell=True,
217+
check=True,
218+
text=True,
219+
capture_output=True,
220+
)
214221
files = result.stdout.splitlines()
215222

216223
# Check if errors.json is in the directory
217224
if "errors.json" in files:
218225
errors_file_path = os.path.join(output_data_path, "errors.json")
219-
226+
220227
# Read the errors.json file
221228
with open(errors_file_path, "r") as f:
222229
errors_content = json.load(f)
223-
230+
224231
# Extract and raise the error message
225232
# error_message = errors_content.get("message", "An error occurred.")
226233
raise Exception(errors_content)
227234

228235
print("No errors.json file found. Directory is clear.")
229-
236+
230237
# except subprocess.CalledProcessError as e:
231238
# print(f"Error listing files in directory: {e}")
232239
# except FileNotFoundError:
@@ -236,6 +243,7 @@ def check_output_for_errors(output_data_path):
236243
# except Exception as e:
237244
# print(f"Raised error: {e}")
238245

246+
239247
def run_operator(
240248
historical_data_path,
241249
additional_data_path,

0 commit comments

Comments
 (0)