Skip to content

Commit 42d6911

Browse files
committed
fill na for non-time based ad
1 parent 8f708e9 commit 42d6911

File tree

1 file changed

+32
-7
lines changed

1 file changed

+32
-7
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,14 @@ def __init__(self, dataset_info, name="historical_data"):
3232
self.dataset_info = dataset_info
3333
self.target_category_columns = dataset_info.target_category_columns
3434
self.target_column_name = dataset_info.target_column
35-
self.dt_column_name = dataset_info.datetime_column.name if dataset_info.datetime_column else None
36-
self.dt_column_format = dataset_info.datetime_column.format if dataset_info.datetime_column else None
35+
self.dt_column_name = (
36+
dataset_info.datetime_column.name if dataset_info.datetime_column else None
37+
)
38+
self.dt_column_format = (
39+
dataset_info.datetime_column.format
40+
if dataset_info.datetime_column
41+
else None
42+
)
3743
self.preprocessing = dataset_info.preprocessing
3844

3945
def run(self, data):
@@ -58,6 +64,7 @@ def run(self, data):
5864
if self.dt_column_name:
5965
clean_df = self._format_datetime_col(clean_df)
6066
clean_df = self._set_multi_index(clean_df)
67+
clean_df = self._fill_na(clean_df) if not self.dt_column_name else clean_df
6168

6269
if self.preprocessing and self.preprocessing.enabled:
6370
if self.name == "historical_data":
@@ -67,7 +74,9 @@ def run(self, data):
6774
except Exception as e:
6875
logger.debug(f"Missing value imputation failed with {e.args}")
6976
else:
70-
logger.info("Skipping missing value imputation because it is disabled")
77+
logger.info(
78+
"Skipping missing value imputation because it is disabled"
79+
)
7180
if self.preprocessing.steps.outlier_treatment:
7281
try:
7382
clean_df = self._outlier_treatment(clean_df)
@@ -78,7 +87,9 @@ def run(self, data):
7887
elif self.name == "additional_data":
7988
clean_df = self._missing_value_imputation_add(clean_df)
8089
else:
81-
logger.info("Skipping all preprocessing steps because preprocessing is disabled")
90+
logger.info(
91+
"Skipping all preprocessing steps because preprocessing is disabled"
92+
)
8293
return clean_df
8394

8495
def _remove_trailing_whitespace(self, df):
@@ -96,7 +107,14 @@ def _set_series_id_column(self, df):
96107
merged_values = df[DataColumns.Series].unique().tolist()
97108
if self.target_category_columns:
98109
for value in merged_values:
99-
self._target_category_columns_map[value] = df[df[DataColumns.Series] == value][self.target_category_columns].drop_duplicates().iloc[0].to_dict()
110+
self._target_category_columns_map[value] = (
111+
df[df[DataColumns.Series] == value][
112+
self.target_category_columns
113+
]
114+
.drop_duplicates()
115+
.iloc[0]
116+
.to_dict()
117+
)
100118

101119
if self.target_category_columns != [DataColumns.Series]:
102120
df = df.drop(self.target_category_columns, axis=1)
@@ -127,7 +145,9 @@ def _set_multi_index(self, df):
127145
"""
128146
if self.dt_column_name:
129147
df = df.set_index([self.dt_column_name, DataColumns.Series])
130-
return df.sort_values([self.dt_column_name, DataColumns.Series], ascending=True)
148+
return df.sort_values(
149+
[self.dt_column_name, DataColumns.Series], ascending=True
150+
)
131151
return df.set_index([df.index, DataColumns.Series])
132152

133153
def _missing_value_imputation_hist(self, df):
@@ -225,5 +245,10 @@ def _check_historical_dataset(self, df):
225245
226246
}
227247
"""
248+
228249
def get_target_category_columns_map(self):
229-
return self._target_category_columns_map
250+
return self._target_category_columns_map
251+
252+
def _fill_na(self, df: pd.DataFrame, na_value=0) -> pd.DataFrame:
253+
"""Fill nans in dataframe"""
254+
return df.fillna(value=na_value)

0 commit comments

Comments
 (0)