Skip to content

Commit 5225e65

Browse files
committed
clean up error messaging
1 parent ebf73e7 commit 5225e65

File tree

6 files changed

+21
-15
lines changed

6 files changed

+21
-15
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
142142
)
143143

144144

145+
def write_json(json_dict, filename, storage_options=None):
146+
with fsspec.open(filename, mode="w", **storage_options) as f:
147+
f.write(json.dumps(json_dict))
148+
149+
145150
def write_simple_json(data, path):
146151
if ObjectStorageDetails.is_oci_path(path):
147152
storage_options = default_signer()

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging
@@ -132,11 +132,12 @@ def _train_model(self, i, s_id, df, model_kwargs):
132132

133133
logger.debug("===========Done===========")
134134
except Exception as e:
135-
self.errors_dict[s_id] = {
135+
new_error = {
136136
"model_name": self.spec.model,
137137
"error": str(e),
138138
"error_trace": traceback.format_exc(),
139139
}
140+
self.errors_dict[s_id] = new_error
140141
logger.warning(f"Encountered Error: {e}. Skipping.")
141142
logger.warning(traceback.format_exc())
142143

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,17 @@ def _build_model(self) -> pd.DataFrame:
184184
"selected_model_params": model.selected_model_params_,
185185
}
186186
except Exception as e:
187-
self.errors_dict[s_id] = {
187+
new_error = {
188188
"model_name": self.spec.model,
189189
"error": str(e),
190190
"error_trace": traceback.format_exc(),
191191
}
192+
if s_id in self.errors_dict:
193+
self.errors_dict[s_id]["model_fitting"] = new_error
194+
else:
195+
self.errors_dict[s_id] = {"model_fitting": new_error}
192196
logger.warning(f"Encountered Error: {e}. Skipping.")
197+
logger.warning(f"self.errors_dict[s_id]: {self.errors_dict[s_id]}")
193198
logger.warning(traceback.format_exc())
194199

195200
logger.debug("===========Forecast Generated===========")

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
merged_category_column_name,
2929
seconds_to_datetime,
3030
write_data,
31+
write_json,
3132
)
3233
from ads.opctl.operator.lowcode.forecast.utils import (
3334
_build_metrics_df,
@@ -634,17 +635,12 @@ def _save_report(
634635
f"The outputs have been successfully generated and placed into the directory: {unique_output_dir}."
635636
)
636637
if self.errors_dict:
637-
write_data(
638-
data=pd.DataFrame(
639-
self.errors_dict, index=np.arange(len(self.errors_dict.keys()))
640-
),
638+
write_json(
639+
json_dict=self.errors_dict,
641640
filename=os.path.join(
642641
unique_output_dir, self.spec.errors_dict_filename
643642
),
644-
format="json",
645643
storage_options=storage_options,
646-
index=True,
647-
indent=4,
648644
)
649645
results.set_errors_dict(self.errors_dict)
650646
else:

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ def _build_model(self) -> pd.DataFrame:
229229
self.models = {}
230230
self.trainers = {}
231231
self.outputs = {}
232-
self.errors_dict = {}
233232
self.explanations_info = {}
234233
self.accepted_regressors = {}
235234
self.additional_regressors = self.datasets.get_additional_data_column_names()

tests/operators/forecast/test_errors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def setup_rossman():
255255

256256
def setup_faulty_rossman():
257257
curr_dir = pathlib.Path(__file__).parent.resolve()
258-
data_folder = f"{curr_dir}/../data/"
258+
data_folder = f"{curr_dir}/../data"
259259
historical_data_path = f"{data_folder}/rs_2_prim.csv"
260260
additional_data_path = f"{data_folder}/rs_2_add_encoded.csv"
261261
return historical_data_path, additional_data_path
@@ -707,10 +707,10 @@ def test_arima_automlx_errors(operator_setup, model):
707707
error_content = json.load(error_file)
708708
assert (
709709
"Input data does not have a consistent (in terms of diff) DatetimeIndex."
710-
in error_content["13"]["error"]
711-
), "Error message mismatch"
710+
in error_content["13"]["model_fitting"]["error"]
711+
), f"Error message mismatch: {error_content}"
712712

713-
if model not in ["autots"]: # , "lgbforecast"
713+
if model not in ["autots", "automlx"]: # , "lgbforecast"
714714
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
715715
global_fn = f"{tmpdirname}/results/global_explanation.csv"
716716
assert os.path.exists(

0 commit comments

Comments
 (0)