Skip to content

Commit de1bd1c

Browse files
authored
Merge pull request #49 from m3dev/add_lgbm_mode_factorys
Add lgbm and catboost model factories
2 parents fcac833 + 98d69ec commit de1bd1c

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

redshells/factory/optuna_param_factory.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,45 @@ def _xgbclassifier_default(trial: optuna.trial.Trial):
2828
return param
2929

3030

31+
def _lgbmclassifier_default(trial: optuna.trial.Trial):
32+
# TODO: using LightGBMTuner
33+
params = {
34+
'boosting_type': trial.suggest_categorical('boosting', ['gbdt', 'dart', 'goss']),
35+
'objective': 'binary',
36+
'metric': ['binary', 'binary_error', 'auc'],
37+
'num_leaves': trial.suggest_int("num_leaves", 10, 500),
38+
'learning_rate': trial.suggest_loguniform("learning_rate", 1e-5, 1),
39+
'feature_fraction': trial.suggest_uniform("feature_fraction", 0.0, 1.0),
40+
}
41+
if params['boosting_type'] == 'dart':
42+
params['drop_rate'] = trial.suggest_loguniform('drop_rate', 1e-8, 1.0)
43+
params['skip_drop'] = trial.suggest_loguniform('skip_drop', 1e-8, 1.0)
44+
if params['boosting_type'] == 'goss':
45+
params['top_rate'] = trial.suggest_uniform('top_rate', 0.0, 1.0)
46+
params['other_rate'] = trial.suggest_uniform('other_rate', 0.0, 1.0 - params['top_rate'])
47+
48+
return params
49+
50+
51+
def _catboostclassifier_default(trial: optuna.trial.Trial):
52+
params = {
53+
'iterations': trial.suggest_int('iterations', 50, 300),
54+
'depth': trial.suggest_int('depth', 4, 10),
55+
'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 0.3),
56+
'random_strength': trial.suggest_int('random_strength', 0, 100),
57+
'bagging_temperature': trial.suggest_loguniform('bagging_temperature', 0.01, 100.00),
58+
'od_type': trial.suggest_categorical('od_type', ['IncToDec', 'Iter']),
59+
'od_wait': trial.suggest_int('od_wait', 10, 50)
60+
}
61+
62+
return params
63+
3164
class _OptunaParamFactory(metaclass=Singleton):
3265
def __init__(self):
3366
self._rules = dict()
3467
self._rules['XGBClassifier_default'] = _xgbclassifier_default
68+
self._rules['LGBMClassifier_default'] = _lgbmclassifier_default
69+
self._rules['CatBoostClassifier_default'] = _catboostclassifier_default
3570

3671
def get(self, key: str, trial: optuna.trial.Trial):
3772
if key not in self._rules:

redshells/factory/prediction_model_factory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ def __init__(self):
1818
except ImportError:
1919
pass
2020

21+
try:
22+
import lightgbm
23+
self._models['LGBMClassifier'] = lightgbm.LGBMClassifier
24+
except ImportError:
25+
pass
26+
27+
try:
28+
import catboost
29+
self._models['CatBoostClassifier'] = catboost.CatBoostClassifier
30+
except ImportError:
31+
pass
32+
2133
def get(self, key: str):
2234
if key in self._models:
2335
return self._models[key]

0 commit comments

Comments
 (0)