Skip to content

Commit 71728e9

Browse files
committed
update dir
1 parent 20545d6 commit 71728e9

File tree

6 files changed

+56
-6
lines changed

6 files changed

+56
-6
lines changed

__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__all__=["ensemble"]
1+
__all__=["cgb"]
22

cgb/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .cgb import cgb_clf
2+
from .cgb import cgb_reg
3+
4+
__all__ = ["cgb_clf", "cgb_reg"]
File renamed without changes.

ensemble/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

test/tets_clf.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,50 @@
1-
#%%
2-
from ensemble import cgb_clf
1+
from sklearn.model_selection import train_test_split
2+
from cgb import cgb_clf, cgb_reg
3+
import sklearn.datasets as dt
4+
import warnings
5+
6+
warnings.simplefilter("ignore")
7+
8+
9+
def model(clf=True):
10+
11+
if clf:
12+
X, y = dt.load_iris(return_X_y=True)
13+
14+
x_train, x_test, y_train, y_test = train_test_split(X,
15+
y,
16+
test_size=0.2)
17+
18+
model_ = cgb_clf(max_depth=5,
19+
subsample=0.5,
20+
max_features='sqrt',
21+
learning_rate=0.05,
22+
random_state=1,
23+
criterion="mse",
24+
loss="log_loss",
25+
n_estimators=100)
26+
27+
else:
28+
X, y = dt.make_regression(n_targets=3)
29+
30+
x_train, x_test, y_train, y_test = train_test_split(X,
31+
y,
32+
test_size=0.2)
33+
model_ = cgb_reg(learning_rate=0.1,
34+
subsample=1,
35+
max_features="sqrt",
36+
loss='ls',
37+
n_estimators=100,
38+
max_depth=3,
39+
random_state=2)
40+
41+
model_.fit(x_train, y_train)
42+
print(model_.score(x_test, y_test))
43+
44+
45+
if __name__ == "__main__":
46+
print('clf')
47+
model(clf=True)
48+
print('-----')
49+
print('reg')
50+
model(clf=False)

0 commit comments

Comments
 (0)