@@ -32,8 +32,14 @@ def __init__(self, dataset_info, name="historical_data"):
32
32
self .dataset_info = dataset_info
33
33
self .target_category_columns = dataset_info .target_category_columns
34
34
self .target_column_name = dataset_info .target_column
35
- self .dt_column_name = dataset_info .datetime_column .name if dataset_info .datetime_column else None
36
- self .dt_column_format = dataset_info .datetime_column .format if dataset_info .datetime_column else None
35
+ self .dt_column_name = (
36
+ dataset_info .datetime_column .name if dataset_info .datetime_column else None
37
+ )
38
+ self .dt_column_format = (
39
+ dataset_info .datetime_column .format
40
+ if dataset_info .datetime_column
41
+ else None
42
+ )
37
43
self .preprocessing = dataset_info .preprocessing
38
44
39
45
def run (self , data ):
@@ -58,6 +64,7 @@ def run(self, data):
58
64
if self .dt_column_name :
59
65
clean_df = self ._format_datetime_col (clean_df )
60
66
clean_df = self ._set_multi_index (clean_df )
67
+ clean_df = self ._fill_na (clean_df ) if not self .dt_column_name else clean_df
61
68
62
69
if self .preprocessing and self .preprocessing .enabled :
63
70
if self .name == "historical_data" :
@@ -67,7 +74,9 @@ def run(self, data):
67
74
except Exception as e :
68
75
logger .debug (f"Missing value imputation failed with { e .args } " )
69
76
else :
70
- logger .info ("Skipping missing value imputation because it is disabled" )
77
+ logger .info (
78
+ "Skipping missing value imputation because it is disabled"
79
+ )
71
80
if self .preprocessing .steps .outlier_treatment :
72
81
try :
73
82
clean_df = self ._outlier_treatment (clean_df )
@@ -78,7 +87,9 @@ def run(self, data):
78
87
elif self .name == "additional_data" :
79
88
clean_df = self ._missing_value_imputation_add (clean_df )
80
89
else :
81
- logger .info ("Skipping all preprocessing steps because preprocessing is disabled" )
90
+ logger .info (
91
+ "Skipping all preprocessing steps because preprocessing is disabled"
92
+ )
82
93
return clean_df
83
94
84
95
def _remove_trailing_whitespace (self , df ):
@@ -96,7 +107,14 @@ def _set_series_id_column(self, df):
96
107
merged_values = df [DataColumns .Series ].unique ().tolist ()
97
108
if self .target_category_columns :
98
109
for value in merged_values :
99
- self ._target_category_columns_map [value ] = df [df [DataColumns .Series ] == value ][self .target_category_columns ].drop_duplicates ().iloc [0 ].to_dict ()
110
+ self ._target_category_columns_map [value ] = (
111
+ df [df [DataColumns .Series ] == value ][
112
+ self .target_category_columns
113
+ ]
114
+ .drop_duplicates ()
115
+ .iloc [0 ]
116
+ .to_dict ()
117
+ )
100
118
101
119
if self .target_category_columns != [DataColumns .Series ]:
102
120
df = df .drop (self .target_category_columns , axis = 1 )
@@ -127,7 +145,9 @@ def _set_multi_index(self, df):
127
145
"""
128
146
if self .dt_column_name :
129
147
df = df .set_index ([self .dt_column_name , DataColumns .Series ])
130
- return df .sort_values ([self .dt_column_name , DataColumns .Series ], ascending = True )
148
+ return df .sort_values (
149
+ [self .dt_column_name , DataColumns .Series ], ascending = True
150
+ )
131
151
return df .set_index ([df .index , DataColumns .Series ])
132
152
133
153
def _missing_value_imputation_hist (self , df ):
@@ -225,5 +245,10 @@ def _check_historical_dataset(self, df):
225
245
226
246
}
227
247
"""
248
+
228
249
def get_target_category_columns_map (self ):
229
- return self ._target_category_columns_map
250
+ return self ._target_category_columns_map
251
+
252
+ def _fill_na (self , df : pd .DataFrame , na_value = 0 ) -> pd .DataFrame :
253
+ """Fill nans in dataframe"""
254
+ return df .fillna (value = na_value )
0 commit comments