Commit 21ccb30

[ENH] Adding leaf node samples to be stored when "quantile" tree is turned on (#45)

#### Reference Issues/PRs

Addresses the quantile-trees part of: neurodata/treeple#29

#### What does this implement/fix? Explain your changes.

1. Stores for each leaf node a 2D numpy array of the y-samples (remember `y` is (n_samples, n_outputs)).
2. Does this all the way down in `Criterion`.
3. Only supports the supervised tree/splitter/criterion.
4. Merges in `main` changes.

Signed-off-by: Adam Li <adam2392@gmail.com>

1 parent 9b07f2a commit 21ccb30
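The core idea of the change, per item 1 above, is that each leaf keeps the 2D array of y-samples that reached it, so quantile predictions become a `np.quantile` over the stored samples of whichever leaf a query point lands in. The following is a minimal, hypothetical sketch of that mechanism (the `Leaf` and `QuantileStump` classes are illustrative stand-ins, not treeple's actual Cython classes):

```python
import numpy as np

class Leaf:
    """A leaf that stores the 2D array of y-samples routed to it."""

    def __init__(self, y_samples):
        # y_samples has shape (n_samples_in_leaf, n_outputs)
        self.y_samples = np.asarray(y_samples, dtype=float)

    def predict_quantile(self, q):
        # One quantile value per output column.
        return np.quantile(self.y_samples, q, axis=0)

class QuantileStump:
    """A depth-1 'tree' on feature 0: left leaf if x < threshold, else right."""

    def __init__(self, threshold, y_left, y_right):
        self.threshold = threshold
        self.left = Leaf(y_left)
        self.right = Leaf(y_right)

    def predict_quantile(self, X, q):
        X = np.asarray(X, dtype=float)
        out = np.empty((X.shape[0], self.left.y_samples.shape[1]))
        for i, x in enumerate(X[:, 0]):
            leaf = self.left if x < self.threshold else self.right
            out[i] = leaf.predict_quantile(q)
        return out

stump = QuantileStump(
    threshold=0.5,
    y_left=[[1.0], [2.0], [3.0]],   # samples routed to the left leaf
    y_right=[[10.0], [20.0]],       # samples routed to the right leaf
)
# Median (q=0.5) per leaf: the left leaf yields 2.0, the right yields 15.0.
print(stump.predict_quantile([[0.0], [1.0]], q=0.5))
```

In the actual commit this bookkeeping happens at the `Criterion` level during tree construction, rather than in a post-hoc wrapper like this sketch.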

File tree

141 files changed (+2511 / -797 lines)

doc/authors_emeritus.rst

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 - Wei Li
 - Paolo Losi
 - Gilles Louppe
+- Chiara Marmo
 - Vincent Michel
 - Jarrod Millman
 - Alexandre Passos

doc/contributor_experience_team.rst

Lines changed: 4 additions & 8 deletions

@@ -18,6 +18,10 @@
 <p>Lucy Liu</p>
 </div>
 <div>
+<a href='https://github.com/MaxwellLZH'><img src='https://avatars.githubusercontent.com/u/16646940?v=4' class='avatar' /></a> <br />
+<p>Maxwell Liu</p>
+</div>
+<div>
 <a href='https://github.com/jmloyola'><img src='https://avatars.githubusercontent.com/u/2133361?v=4' class='avatar' /></a> <br />
 <p>Juan Martin Loyola</p>
 </div>
@@ -26,14 +30,6 @@
 <p>Sylvain Marié</p>
 </div>
 <div>
-<a href='https://github.com/cmarmo'><img src='https://avatars.githubusercontent.com/u/1662261?v=4' class='avatar' /></a> <br />
-<p>Chiara Marmo</p>
-</div>
-<div>
-<a href='https://github.com/MaxwellLZH'><img src='https://avatars.githubusercontent.com/u/16646940?v=4' class='avatar' /></a> <br />
-<p>Maxwell Liu</p>
-</div>
-<div>
 <a href='https://github.com/norbusan'><img src='https://avatars.githubusercontent.com/u/1735589?v=4' class='avatar' /></a> <br />
 <p>Norbert Preining</p>
 </div>

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions

@@ -1247,6 +1247,7 @@ Visualization
 :template: display_only_from_estimator.rst

 model_selection.LearningCurveDisplay
+model_selection.ValidationCurveDisplay

 .. _multiclass_ref:

doc/modules/learning_curve.rst

Lines changed: 31 additions & 11 deletions

@@ -71,7 +71,7 @@ The function :func:`validation_curve` can help in this case::
     >>> import numpy as np
     >>> from sklearn.model_selection import validation_curve
     >>> from sklearn.datasets import load_iris
-    >>> from sklearn.linear_model import Ridge
+    >>> from sklearn.svm import SVC

     >>> np.random.seed(0)
     >>> X, y = load_iris(return_X_y=True)
@@ -80,30 +80,50 @@ The function :func:`validation_curve` can help in this case::
     >>> X, y = X[indices], y[indices]

     >>> train_scores, valid_scores = validation_curve(
-    ...     Ridge(), X, y, param_name="alpha", param_range=np.logspace(-7, 3, 3),
-    ...     cv=5)
+    ...     SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 3),
+    ... )
     >>> train_scores
-    array([[0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
-           [0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
-           [0.51..., 0.52..., 0.49..., 0.47..., 0.49...]])
+    array([[0.90..., 0.94..., 0.91..., 0.89..., 0.92...],
+           [0.9... , 0.92..., 0.93..., 0.92..., 0.93...],
+           [0.97..., 1... , 0.98..., 0.97..., 0.99...]])
     >>> valid_scores
-    array([[0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
-           [0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
-           [0.46..., 0.25..., 0.50..., 0.49..., 0.52...]])
+    array([[0.9..., 0.9... , 0.9... , 0.96..., 0.9... ],
+           [0.9..., 0.83..., 0.96..., 0.96..., 0.93...],
+           [1.... , 0.93..., 1.... , 1.... , 0.9... ]])
+
+If you intend to plot the validation curves only, the class
+:class:`~sklearn.model_selection.ValidationCurveDisplay` is more direct than
+using matplotlib manually on the results of a call to :func:`validation_curve`.
+You can use the method
+:meth:`~sklearn.model_selection.ValidationCurveDisplay.from_estimator` similarly
+to :func:`validation_curve` to generate and plot the validation curve:
+
+.. plot::
+   :context: close-figs
+   :align: center
+
+   from sklearn.datasets import load_iris
+   from sklearn.model_selection import ValidationCurveDisplay
+   from sklearn.svm import SVC
+   from sklearn.utils import shuffle
+   X, y = load_iris(return_X_y=True)
+   X, y = shuffle(X, y, random_state=0)
+   ValidationCurveDisplay.from_estimator(
+       SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 10)
+   )

 If the training score and the validation score are both low, the estimator will
 be underfitting. If the training score is high and the validation score is low,
 the estimator is overfitting and otherwise it is working very well. A low
 training score and a high validation score is usually not possible. Underfitting,
 overfitting, and a working model are shown in the in the plot below where we vary
-the parameter :math:`\gamma` of an SVM on the digits dataset.
+the parameter `gamma` of an SVM with an RBF kernel on the digits dataset.

 .. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_validation_curve_001.png
    :target: ../auto_examples/model_selection/plot_validation_curve.html
    :align: center
    :scale: 50%

-
 .. _learning_curve:

 Learning curve
doc/visualizations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,4 @@ Display Objects
8989
metrics.PredictionErrorDisplay
9090
metrics.RocCurveDisplay
9191
model_selection.LearningCurveDisplay
92+
model_selection.ValidationCurveDisplay
