Skip to content

Commit 7e38502

Browse files
committed
Successful merge with the missing value support
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent f82f258 commit 7e38502

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

sklearn/tree/_classes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def _fit(
388388
X,
389389
y,
390390
sample_weight,
391+
feature_has_missing,
391392
min_samples_leaf,
392393
min_weight_leaf,
393394
max_leaf_nodes,
@@ -403,6 +404,7 @@ def _build_tree(
403404
X,
404405
y,
405406
sample_weight,
407+
feature_has_missing,
406408
min_samples_leaf,
407409
min_weight_leaf,
408410
max_leaf_nodes,

sklearn/tree/tests/test_tree.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def test_xor():
300300
clf.fit(X, y)
301301
assert clf.score(X, y) == 1.0, "Failed with {0}".format(name)
302302

303-
clf = Tree(random_state=0, max_features=X.shape[1])
303+
clf = Tree(random_state=0, max_features=1)
304304
clf.fit(X, y)
305305
assert clf.score(X, y) == 1.0, "Failed with {0}".format(name)
306306

@@ -440,7 +440,7 @@ def test_importances():
440440
X, y = datasets.make_classification(
441441
n_samples=5000,
442442
n_features=10,
443-
n_informative=4,
443+
n_informative=3,
444444
n_redundant=0,
445445
n_repeated=0,
446446
shuffle=False,
@@ -455,7 +455,7 @@ def test_importances():
455455
n_important = np.sum(importances > 0.1)
456456

457457
assert importances.shape[0] == 10, "Failed with {0}".format(name)
458-
assert n_important == 4, "Failed with {0}".format(name)
458+
assert n_important == 3, "Failed with {0}".format(name)
459459

460460
# Check on iris that importances are the same for all builders
461461
clf = DecisionTreeClassifier(random_state=0)
@@ -466,9 +466,9 @@ def test_importances():
466466
assert_array_equal(clf.feature_importances_, clf2.feature_importances_)
467467

468468

469-
@pytest.mark.parametrize("clf", [DecisionTreeClassifier()])
470-
def test_importances_raises(clf):
469+
def test_importances_raises():
471470
# Check if variable importance before fit raises ValueError.
471+
clf = DecisionTreeClassifier()
472472
with pytest.raises(ValueError):
473473
getattr(clf, "feature_importances_")
474474

@@ -653,7 +653,6 @@ def test_min_samples_leaf():
653653
est.fit(X, y)
654654
out = est.tree_.apply(X)
655655
node_counts = np.bincount(out)
656-
657656
# drop inner nodes
658657
leaf_count = node_counts[node_counts != 0]
659658
assert np.min(leaf_count) > 4, "Failed with {0}".format(name)
@@ -678,7 +677,7 @@ def check_min_weight_fraction_leaf(name, datasets, sparse=False):
678677
else:
679678
X = DATASETS[datasets]["X"].astype(np.float32)
680679
y = DATASETS[datasets]["y"]
681-
rng = np.random.RandomState(42)
680+
682681
weights = rng.rand(X.shape[0])
683682
total_weight = np.sum(weights)
684683

@@ -829,7 +828,7 @@ def test_min_impurity_decrease():
829828
)
830829
# Check with a much lower value of 0.0001
831830
est3 = TreeEstimator(
832-
max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=0.0001, random_state=1
831+
max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=0.0001, random_state=0
833832
)
834833
# Check with a much lower value of 0.1
835834
est4 = TreeEstimator(
@@ -919,7 +918,6 @@ def test_pickle():
919918
est2 = pickle.loads(serialized_object)
920919
assert type(est2) == est.__class__
921920

922-
# score should match before/after pickling
923921
score2 = est2.score(X, y)
924922
assert (
925923
score == score2
@@ -1033,6 +1031,7 @@ def test_memory_layout():
10331031
ALL_TREES.items(), [np.float64, np.float32]
10341032
):
10351033
est = TreeEstimator(random_state=0)
1034+
10361035
# Nothing
10371036
X = np.asarray(iris.data, dtype=dtype)
10381037
y = iris.target
@@ -1053,11 +1052,6 @@ def test_memory_layout():
10531052
y = iris.target
10541053
assert_array_equal(est.fit(X, y).predict(X), y)
10551054

1056-
# Strided
1057-
X = np.asarray(iris.data[::3], dtype=dtype)
1058-
y = iris.target[::3]
1059-
assert_array_equal(est.fit(X, y).predict(X), y)
1060-
10611055
# csr matrix
10621056
X = csr_matrix(iris.data, dtype=dtype)
10631057
y = iris.target
@@ -1068,6 +1062,11 @@ def test_memory_layout():
10681062
y = iris.target
10691063
assert_array_equal(est.fit(X, y).predict(X), y)
10701064

1065+
# Strided
1066+
X = np.asarray(iris.data[::3], dtype=dtype)
1067+
y = iris.target[::3]
1068+
assert_array_equal(est.fit(X, y).predict(X), y)
1069+
10711070

10721071
def test_sample_weight():
10731072
# Check sample weighting.
@@ -1261,7 +1260,7 @@ def test_behaviour_constant_feature_after_splits():
12611260
y = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3]
12621261
for name, TreeEstimator in ALL_TREES.items():
12631262
# do not check extra random trees
1264-
if all(_name not in name for _name in ["ExtraTree"]):
1263+
if "ExtraTree" not in name:
12651264
est = TreeEstimator(random_state=0, max_features=1)
12661265
est.fit(X, y)
12671266
assert est.tree_.max_depth == 2
@@ -1587,7 +1586,6 @@ def check_min_weight_leaf_split_level(name):
15871586
sample_weight = [0.2, 0.2, 0.2, 0.2, 0.2]
15881587
_check_min_weight_leaf_split_level(TreeEstimator, X, y, sample_weight)
15891588

1590-
# skip for sparse inputs
15911589
_check_min_weight_leaf_split_level(TreeEstimator, csc_matrix(X), y, sample_weight)
15921590

15931591

@@ -1646,7 +1644,6 @@ def check_decision_path(name):
16461644
# Assert that leaves index are correct
16471645
leaves = est.apply(X)
16481646
leave_indicator = [node_indicator[i, j] for i, j in enumerate(leaves)]
1649-
16501647
assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples))
16511648

16521649
# Ensure only one leave node per sample
@@ -1933,7 +1930,6 @@ def assert_is_subtree(tree, subtree):
19331930
def test_apply_path_readonly_all_trees(name, splitter, X_format):
19341931
dataset = DATASETS["clf_small"]
19351932
X_small = dataset["X"].astype(tree._tree.DTYPE, copy=False)
1936-
19371933
if X_format == "dense":
19381934
X_readonly = create_memmap_backed_data(X_small)
19391935
else:

0 commit comments

Comments (0)