Skip to content

Commit 5208701

Browse files
committed
LightGBM requires re-formatting column names
1 parent 96624ed commit 5208701

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

6+
from abc import ABC
7+
8+
import pandas as pd
9+
import re
10+
711
from ads.opctl import logger
12+
from ads.opctl.operator.lowcode.common.const import DataColumns
813
from ads.opctl.operator.lowcode.common.errors import (
9-
InvalidParameterError,
1014
DataMismatchError,
15+
InvalidParameterError,
1116
)
12-
from ads.opctl.operator.lowcode.common.const import DataColumns
1317
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
14-
import pandas as pd
15-
from abc import ABC
1618

1719

1820
class Transformations(ABC):
@@ -58,6 +60,7 @@ def run(self, data):
5860
5961
"""
6062
clean_df = self._remove_trailing_whitespace(data)
63+
clean_df = self._normalize_column_names(clean_df)
6164
if self.name == "historical_data":
6265
self._check_historical_dataset(clean_df)
6366
clean_df = self._set_series_id_column(clean_df)
@@ -95,8 +98,11 @@ def run(self, data):
9598
def _remove_trailing_whitespace(self, df):
9699
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
97100

101+
def _normalize_column_names(self, df):
102+
return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
103+
98104
def _set_series_id_column(self, df):
99-
self._target_category_columns_map = dict()
105+
self._target_category_columns_map = {}
100106
if not self.target_category_columns:
101107
df[DataColumns.Series] = "Series 1"
102108
self.has_artificial_series = True
@@ -125,10 +131,10 @@ def _format_datetime_col(self, df):
125131
df[self.dt_column_name] = pd.to_datetime(
126132
df[self.dt_column_name], format=self.dt_column_format
127133
)
128-
except:
134+
except Exception as ee:
129135
raise InvalidParameterError(
130136
f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
131-
)
137+
) from ee
132138
return df
133139

134140
def _set_multi_index(self, df):
@@ -242,7 +248,6 @@ def _check_historical_dataset(self, df):
242248
"Class": "A",
243249
"Num": 2
244250
},
245-
246251
}
247252
"""
248253

0 commit comments

Comments
 (0)