@@ -28,10 +28,45 @@ def _xgbclassifier_default(trial: optuna.trial.Trial):
28
28
return param
29
29
30
30
31
+ def _lgbmclassifier_default (trial : optuna .trial .Trial ):
32
+ # TODO: using LightGBMTuner
33
+ params = {
34
+ 'boosting_type' : trial .suggest_categorical ('boosting' , ['gbdt' , 'dart' , 'goss' ]),
35
+ 'objective' : 'binary' ,
36
+ 'metric' : ['binary' , 'binary_error' , 'auc' ],
37
+ 'num_leaves' : trial .suggest_int ("num_leaves" , 10 , 500 ),
38
+ 'learning_rate' : trial .suggest_loguniform ("learning_rate" , 1e-5 , 1 ),
39
+ 'feature_fraction' : trial .suggest_uniform ("feature_fraction" , 0.0 , 1.0 ),
40
+ }
41
+ if params ['boosting_type' ] == 'dart' :
42
+ params ['drop_rate' ] = trial .suggest_loguniform ('drop_rate' , 1e-8 , 1.0 )
43
+ params ['skip_drop' ] = trial .suggest_loguniform ('skip_drop' , 1e-8 , 1.0 )
44
+ if params ['boosting_type' ] == 'goss' :
45
+ params ['top_rate' ] = trial .suggest_uniform ('top_rate' , 0.0 , 1.0 )
46
+ params ['other_rate' ] = trial .suggest_uniform ('other_rate' , 0.0 , 1.0 - params ['top_rate' ])
47
+
48
+ return params
49
+
50
+
51
+ def _catboostclassifier_default (trial : optuna .trial .Trial ):
52
+ params = {
53
+ 'iterations' : trial .suggest_int ('iterations' , 50 , 300 ),
54
+ 'depth' : trial .suggest_int ('depth' , 4 , 10 ),
55
+ 'learning_rate' : trial .suggest_loguniform ('learning_rate' , 0.01 , 0.3 ),
56
+ 'random_strength' : trial .suggest_int ('random_strength' , 0 , 100 ),
57
+ 'bagging_temperature' : trial .suggest_loguniform ('bagging_temperature' , 0.01 , 100.00 ),
58
+ 'od_type' : trial .suggest_categorical ('od_type' , ['IncToDec' , 'Iter' ]),
59
+ 'od_wait' : trial .suggest_int ('od_wait' , 10 , 50 )
60
+ }
61
+
62
+ return params
63
+
31
64
class _OptunaParamFactory (metaclass = Singleton ):
32
65
def __init__ (self ):
33
66
self ._rules = dict ()
34
67
self ._rules ['XGBClassifier_default' ] = _xgbclassifier_default
68
+ self ._rules ['LGBMClassifier_default' ] = _lgbmclassifier_default
69
+ self ._rules ['CatBoostClassifier_default' ] = _catboostclassifier_default
35
70
36
71
def get (self , key : str , trial : optuna .trial .Trial ):
37
72
if key not in self ._rules :
0 commit comments