Skip to content

Commit 1db69ab

Browse files
committed
Updated forecast operator schema for preprocessing
1 parent d209cc1 commit 1db69ab

File tree

4 files changed

+59
-21
lines changed

4 files changed

+59
-21
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,26 @@ def run(self, data):
5858
clean_df = self._format_datetime_col(clean_df)
5959
clean_df = self._set_multi_index(clean_df)
6060

61-
if self.name == "historical_data":
62-
try:
63-
clean_df = self._missing_value_imputation_hist(clean_df)
64-
except Exception as e:
65-
logger.debug(f"Missing value imputation failed with {e.args}")
66-
if self.preprocessing:
67-
try:
68-
clean_df = self._outlier_treatment(clean_df)
69-
except Exception as e:
70-
logger.debug(f"Outlier Treatment failed with {e.args}")
71-
else:
72-
logger.debug("Skipping outlier treatment as preprocessing is disabled")
73-
elif self.name == "additional_data":
74-
clean_df = self._missing_value_imputation_add(clean_df)
61+
if self.preprocessing.enabled:
62+
if self.name == "historical_data":
63+
if self.preprocessing.steps.missing_value_imputation:
64+
try:
65+
clean_df = self._missing_value_imputation_hist(clean_df)
66+
except Exception as e:
67+
logger.debug(f"Missing value imputation failed with {e.args}")
68+
else:
69+
logger.info("Skipping missing value imputation because it is disabled")
70+
if self.preprocessing.steps.outlier_treatment:
71+
try:
72+
clean_df = self._outlier_treatment(clean_df)
73+
except Exception as e:
74+
logger.debug(f"Outlier Treatment failed with {e.args}")
75+
else:
76+
logger.info("Skipping outlier treatment because it is disabled")
77+
elif self.name == "additional_data":
78+
clean_df = self._missing_value_imputation_add(clean_df)
79+
else:
80+
logger.info("Skipping all preprocessing steps because preprocessing is disabled")
7581
return clean_df
7682

7783
def _remove_trailing_whitespace(self, df):

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@ class DateTimeColumn(DataClassSerializable):
2929
format: str = None
3030

3131

32+
@dataclass(repr=True)
33+
class PreprocessingSteps(DataClassSerializable):
34+
"""Class representing the individual preprocessing steps that can be switched on or off."""
35+
36+
missing_value_imputation: bool = True
37+
outlier_treatment: bool = True
38+
39+
40+
@dataclass(repr=True)
41+
class DataPreprocessor(DataClassSerializable):
42+
"""Class representing operator specification preprocessing details."""
43+
44+
enabled: bool = True
45+
steps: PreprocessingSteps = field(default_factory=PreprocessingSteps)
46+
47+
3248
@dataclass(repr=True)
3349
class Tuning(DataClassSerializable):
3450
"""Class representing operator specification tuning details."""
@@ -54,7 +70,7 @@ class ForecastOperatorSpec(DataClassSerializable):
5470
global_explanation_filename: str = None
5571
local_explanation_filename: str = None
5672
target_column: str = None
57-
preprocessing: bool = None
73+
preprocessing: DataPreprocessor = field(default_factory=DataPreprocessor)
5874
datetime_column: DateTimeColumn = field(default_factory=DateTimeColumn)
5975
target_category_columns: List[str] = field(default_factory=list)
6076
generate_report: bool = None
@@ -79,7 +95,7 @@ def __post_init__(self):
7995
self.confidence_interval_width = self.confidence_interval_width or 0.80
8096
self.report_filename = self.report_filename or "report.html"
8197
self.preprocessing = (
82-
self.preprocessing if self.preprocessing is not None else True
98+
self.preprocessing if self.preprocessing is not None else DataPreprocessor(enabled=True)
8399
)
84100
# For Report Generation. When user doesn't specify defaults to True
85101
self.generate_report = (

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,23 @@ spec:
286286
default: target
287287

288288
preprocessing:
289-
type: boolean
289+
type: dict
290290
required: false
291-
default: true
292-
meta:
293-
description: "preprocessing and feature engineering can be disabled using this flag, Defaults to true"
291+
schema:
292+
enabled:
293+
type: boolean
294+
required: false
295+
default: true
296+
meta:
297+
description: "preprocessing and feature engineering can be disabled using this flag, Defaults to true"
298+
missing_value_imputation:
299+
type: boolean
300+
required: false
301+
default: true
302+
outlier_treatment:
303+
type: boolean
304+
required: false
305+
default: true
294306

295307
generate_explanations:
296308
type: boolean

tests/operators/forecast/test_errors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def populate_yaml(
185185
additional_data_path=None,
186186
test_data_path=None,
187187
output_data_path=None,
188+
preprocessing=None,
188189
):
189190
if historical_data_path is None:
190191
historical_data_path, additional_data_path, test_data_path = setup_rossman()
@@ -204,7 +205,7 @@ def populate_yaml(
204205
yaml_i["spec"]["datetime_column"]["name"] = "Date"
205206
yaml_i["spec"]["target_category_columns"] = ["Store"]
206207
yaml_i["spec"]["horizon"] = HORIZON
207-
208+
yaml_i["spec"]["preprocessing"] = preprocessing
208209
if generate_train_metrics:
209210
yaml_i["spec"]["generate_metrics"] = generate_train_metrics
210211
if model == "autots":
@@ -372,6 +373,7 @@ def test_0_series(operator_setup, model):
372373
historical_data_path=historical_data_path,
373374
additional_data_path=additional_data_path,
374375
test_data_path=test_data_path,
376+
preprocessing={"enabled": False}
375377
)
376378
with pytest.raises(DataMismatchError):
377379
run_yaml(
@@ -454,12 +456,14 @@ def split_df(df):
454456
historical_data_path, additional_data_path, test_data_path = setup_artificial_data(
455457
tmpdirname, hist_data, add_data, test_data
456458
)
459+
preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
457460
yaml_i, output_data_path = populate_yaml(
458461
tmpdirname=tmpdirname,
459462
model=model,
460463
historical_data_path=historical_data_path,
461464
additional_data_path=additional_data_path,
462465
test_data_path=test_data_path,
466+
preprocessing={"enabled": True, "steps": preprocessing_steps}
463467
)
464468
with pytest.raises(DataMismatchError):
465469
# 4 columns in historical data, but only 1 cat col specified

0 commit comments

Comments
 (0)