Commit d9c9bd4

Commit message: Merging main

Signed-off-by: Adam Li <adam2392@gmail.com>
Parents: 600187a, 0b0b90b

126 files changed: +1865 / -1024 lines


azure-pipelines.yml

Lines changed: 5 additions & 0 deletions

@@ -40,6 +40,11 @@ jobs:
     - bash: |
         ./build_tools/linting.sh
       displayName: Run linters
+    - bash: |
+        pip install ninja meson scipy
+        python build_tools/check-meson-openmp-dependencies.py
+      displayName: Run Meson OpenMP checks
+
 
 - template: build_tools/azure/posix.yml
   parameters:
build_tools/check-meson-openmp-dependencies.py

Lines changed: 172 additions & 0 deletions

@@ -0,0 +1,172 @@
"""
Check that OpenMP dependencies are correctly defined in meson.build files.

This is based on trying to make sure that the following two things match:
- the Cython files using OpenMP (based on a git grep regex)
- the Cython extension modules that are built with OpenMP compiler flags (based
  on meson introspect json output)
"""

import json
import re
import subprocess
from pathlib import Path


def has_source_openmp_flags(target_source):
    return any("openmp" in arg for arg in target_source["parameters"])


def has_openmp_flags(target):
    """Return whether target sources use OpenMP flags.

    Make sure that both the compiler and the linker sources use OpenMP.
    Look at the `get_meson_info` docstring to see what `target` looks like.
    """
    target_sources = target["target_sources"]

    target_use_openmp_flags = any(
        has_source_openmp_flags(target_source) for target_source in target_sources
    )

    if not target_use_openmp_flags:
        return False

    # When the target uses OpenMP we expect a compiler + linker source and we
    # want to make sure that both the compiler and the linker use OpenMP
    assert len(target_sources) == 2
    compiler_source, linker_source = target_sources
    assert "compiler" in compiler_source
    assert "linker" in linker_source

    compiler_use_openmp_flags = any(
        "openmp" in arg for arg in compiler_source["parameters"]
    )
    linker_use_openmp_flags = any(
        "openmp" in arg for arg in linker_source["parameters"]
    )

    assert compiler_use_openmp_flags == linker_use_openmp_flags
    return compiler_use_openmp_flags


def get_canonical_name_meson(target, build_path):
    """Return a name based on the generated shared library.

    The goal is to return a name that can be easily matched with the output
    from `get_git_grep_info`.

    Look at the `get_meson_info` docstring to see what `target` looks like.
    """
    # Expect a list with one element containing the name of the shared library
    assert len(target["filename"]) == 1
    shared_library_path = Path(target["filename"][0])
    shared_library_relative_path = shared_library_path.relative_to(
        build_path.absolute()
    )
    # Needed on Windows to match git grep output
    rel_path = shared_library_relative_path.as_posix()
    # OS-specific naming of the shared library: .cpython- on POSIX and
    # something like .cp312- on Windows
    pattern = r"\.(cpython|cp\d+)-.+"
    return re.sub(pattern, "", str(rel_path))


def get_canonical_name_git_grep(filename):
    """Return a name based on the filename.

    The goal is to return a name that can easily be matched with the output
    from `get_meson_info`.
    """
    return re.sub(r"\.pyx(\.tp)?", "", filename)


def get_meson_info():
    """Return names of extensions that use OpenMP based on meson introspect output.

    The meson introspect json info is a list of targets where a target is a dict
    that looks like this (parts not used in this script are not shown for
    simplicity):
    {
        'name': '_k_means_elkan.cpython-312-x86_64-linux-gnu',
        'filename': [
            '<meson_build_dir>/sklearn/cluster/_k_means_elkan.cpython-312-x86_64-linux-gnu.so'
        ],
        'target_sources': [
            {
                'compiler': ['ccache', 'cc'],
                'parameters': [
                    '-Wall',
                    '-std=c11',
                    '-fopenmp',
                    ...
                ],
                ...
            },
            {
                'linker': ['cc'],
                'parameters': [
                    '-shared',
                    '-fPIC',
                    '-fopenmp',
                    ...
                ]
            }
        ]
    }
    """
    build_path = Path("build/introspect")
    subprocess.check_call(["meson", "setup", build_path, "--reconfigure"])

    json_out = subprocess.check_output(
        ["meson", "introspect", build_path, "--targets"], text=True
    )
    target_list = json.loads(json_out)
    meson_targets = [target for target in target_list if has_openmp_flags(target)]

    return [get_canonical_name_meson(each, build_path) for each in meson_targets]


def get_git_grep_info():
    """Return names of extensions that use OpenMP based on a git grep regex."""
    git_grep_filenames = subprocess.check_output(
        ["git", "grep", "-lP", "cython.*parallel|_openmp_helpers"], text=True
    ).splitlines()
    git_grep_filenames = [f for f in git_grep_filenames if ".pyx" in f]

    return [get_canonical_name_git_grep(each) for each in git_grep_filenames]


def main():
    from_meson = set(get_meson_info())
    from_git_grep = set(get_git_grep_info())

    only_in_git_grep = from_git_grep - from_meson
    only_in_meson = from_meson - from_git_grep

    msg = ""
    if only_in_git_grep:
        only_in_git_grep_msg = "\n".join(
            [f"  {each}" for each in sorted(only_in_git_grep)]
        )
        msg += (
            "Some Cython files use OpenMP,"
            " but their meson.build is missing the openmp_dep dependency:\n"
            f"{only_in_git_grep_msg}\n\n"
        )

    if only_in_meson:
        only_in_meson_msg = "\n".join([f"  {each}" for each in sorted(only_in_meson)])
        msg += (
            "Some Cython files do not use OpenMP,"
            " you should remove openmp_dep from their meson.build:\n"
            f"{only_in_meson_msg}\n\n"
        )

    if from_meson != from_git_grep:
        raise ValueError(
            f"Some issues have been found in Meson OpenMP dependencies:\n\n{msg}"
        )


if __name__ == "__main__":
    main()
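To make the matching concrete, here is a minimal sketch (not part of the commit) of how the two canonicalization helpers are expected to agree: both regexes reduce a built extension and its Cython source to the same module path. The file names are illustrative and assume a Linux CPython 3.12 build.

import re

# Illustrative inputs; see the get_meson_info docstring for the real shape
meson_filename = "sklearn/cluster/_k_means_elkan.cpython-312-x86_64-linux-gnu.so"
git_grep_filename = "sklearn/cluster/_k_means_elkan.pyx"

# Same substitutions as get_canonical_name_meson / get_canonical_name_git_grep
from_meson = re.sub(r"\.(cpython|cp\d+)-.+", "", meson_filename)
from_git_grep = re.sub(r"\.pyx(\.tp)?", "", git_grep_filename)

assert from_meson == from_git_grep == "sklearn/cluster/_k_means_elkan"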

doc/api_reference.py

Lines changed: 1 addition & 0 deletions

@@ -1183,6 +1183,7 @@ def _get_submodule(module_name, submodule_name):
             "validation.check_symmetric",
             "validation.column_or_1d",
             "validation.has_fit_parameter",
+            "validation.validate_data",
         ],
     },
     {

doc/developers/develop.rst

Lines changed: 0 additions & 9 deletions

@@ -562,15 +562,6 @@ for your estimator's tags. For example::
 You can create a new subclass of :class:`~sklearn.utils.Tags` if you wish
 to add new tags to the existing set.
 
-In addition to the tags, estimators also need to declare any non-optional
-parameters to ``__init__`` in the ``_required_parameters`` class attribute,
-which is a list or tuple. If ``_required_parameters`` is only
-``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
-instantiated with an instance of ``LogisticRegression`` (or
-``RidgeRegression`` if the estimator is a regressor) in the tests. The choice
-of these two models is somewhat idiosyncratic but both should provide robust
-closed-form solutions.
-
 .. _developer_api_set_output:
 
 Developer API for `set_output`
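The remaining text points developers at the tags API instead. As a minimal sketch of what overriding tags looks like (assuming the `__sklearn_tags__` protocol referenced in the hunk context; treating `non_deterministic` as one of the standard tag fields is an assumption here):

from sklearn.base import BaseEstimator


class MyEstimator(BaseEstimator):
    def __sklearn_tags__(self):
        # Start from the parent tags and tweak only what differs
        tags = super().__sklearn_tags__()
        tags.non_deterministic = True  # assumed standard tag field
        return tags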

doc/whats_new/v1.6.rst

Lines changed: 14 additions & 0 deletions

@@ -25,6 +25,12 @@ Version 1.6.0
 Changes impacting many modules
 ------------------------------
 
+- |API| :func:`utils.validation.validate_data` is introduced and replaces the
+  previously private `base.BaseEstimator._validate_data` method. It is intended
+  for third-party estimator developers, who should use it in most cases instead
+  of :func:`utils.validation.check_array` and :func:`utils.validation.check_X_y`.
+  :pr:`29696` by `Adrin Jalali`_.
+
 - |Enhancement| `__sklearn_tags__` was introduced for setting tags in estimators.
   More details in :ref:`estimator_tags`.
   :pr:`22606` by `Thomas Fan`_ and :pr:`29677` by `Adrin Jalali`_.

@@ -247,6 +253,10 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Fix| :class:`linear_model.LogisticRegressionCV` corrects sample weight handling
+  for the calculation of test scores.
+  :pr:`29419` by :user:`Shruti Nath <snath-xoc>`.
+
 - |API| Deprecates `copy_X` in :class:`linear_model.TheilSenRegressor` as the parameter
   has no effect. `copy_X` will be removed in 1.8.
   :pr:`29105` by :user:`Adam Li <adam2392>`.

@@ -271,6 +281,10 @@ Changelog
   :pr:`29210` by :user:`Marc Torrellas Socastro <marctorsoc>` and
   :user:`Stefanie Senger <StefanieSenger>`.
 
+- |Efficiency| :func:`sklearn.metrics.classification_report` is now faster by caching
+  classification labels.
+  :pr:`29738` by `Adrin Jalali`_.
+
 - |API| scoring="neg_max_error" should be used instead of
   scoring="max_error", which is now deprecated.
   :pr:`29462` by :user:`Farid "Freddie" Taba <artificialfintelligence>`.
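For context on the `validate_data` entry above, a minimal sketch of the intended third-party usage might look like this (the estimator is hypothetical; the signature assumes the scikit-learn 1.6 API, where the estimator instance is passed as the first argument):

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import validate_data


class MyClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y):
        # Public replacement for the previously private self._validate_data
        X, y = validate_data(self, X, y)
        self.classes_ = np.unique(y)
        return self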

examples/gaussian_process/plot_gpr_noisy.py

Lines changed: 59 additions & 23 deletions

@@ -33,7 +33,7 @@ def target_generator(X, add_noise=False):
 # %%
 # Let's have a look at the target generator where we will not add any noise to
 # observe the signal that we would like to predict.
-X = np.linspace(0, 5, num=30).reshape(-1, 1)
+X = np.linspace(0, 5, num=80).reshape(-1, 1)
 y = target_generator(X, add_noise=False)
 
 # %%

@@ -88,7 +88,7 @@ def target_generator(X, add_noise=False):
 from sklearn.gaussian_process.kernels import RBF, WhiteKernel
 
 kernel = 1.0 * RBF(length_scale=1e1, length_scale_bounds=(1e-2, 1e3)) + WhiteKernel(
-    noise_level=1, noise_level_bounds=(1e-5, 1e1)
+    noise_level=1, noise_level_bounds=(1e-10, 1e1)
 )
 gpr = GaussianProcessRegressor(kernel=kernel, alpha=0.0)
 gpr.fit(X_train, y_train)

@@ -97,7 +97,7 @@ def target_generator(X, add_noise=False):
 # %%
 plt.plot(X, y, label="Expected signal")
 plt.scatter(x=X_train[:, 0], y=y_train, color="black", alpha=0.4, label="Observations")
-plt.errorbar(X, y_mean, y_std)
+plt.errorbar(X, y_mean, y_std, label="Posterior mean ± std")
 plt.legend()
 plt.xlabel("X")
 plt.ylabel("y")

@@ -109,15 +109,18 @@ def target_generator(X, add_noise=False):
     fontsize=8,
 )
 # %%
-# We see that the optimum kernel found still have a high noise level and
-# an even larger length scale. Furthermore, we observe that the
-# model does not provide faithful predictions.
+# We see that the optimum kernel found still has a high noise level and an even
+# larger length scale. The length scale reaches the maximum bound that we
+# allowed for this parameter, and we got a warning as a result.
 #
-# Now, we will initialize the
-# :class:`~sklearn.gaussian_process.kernels.RBF` with a
-# larger `length_scale` and the
-# :class:`~sklearn.gaussian_process.kernels.WhiteKernel`
-# with a smaller noise level lower bound.
+# More importantly, we observe that the model does not provide useful
+# predictions: the mean prediction seems to be constant, as it does not follow
+# the expected noise-free signal.
+#
+# Now, we will initialize the :class:`~sklearn.gaussian_process.kernels.RBF`
+# with a larger `length_scale` initial value and the
+# :class:`~sklearn.gaussian_process.kernels.WhiteKernel` with a smaller initial
+# noise level, while keeping the parameter bounds unchanged.
 kernel = 1.0 * RBF(length_scale=1e-1, length_scale_bounds=(1e-2, 1e3)) + WhiteKernel(
     noise_level=1e-2, noise_level_bounds=(1e-10, 1e1)
 )

@@ -153,21 +156,19 @@ def target_generator(X, add_noise=False):
 # for different hyperparameters to get a sense of the local minima.
 from matplotlib.colors import LogNorm
 
-length_scale = np.logspace(-2, 4, num=50)
-noise_level = np.logspace(-2, 1, num=50)
+length_scale = np.logspace(-2, 4, num=80)
+noise_level = np.logspace(-2, 1, num=80)
 length_scale_grid, noise_level_grid = np.meshgrid(length_scale, noise_level)
 
 log_marginal_likelihood = [
     gpr.log_marginal_likelihood(theta=np.log([0.36, scale, noise]))
     for scale, noise in zip(length_scale_grid.ravel(), noise_level_grid.ravel())
 ]
-log_marginal_likelihood = np.reshape(
-    log_marginal_likelihood, newshape=noise_level_grid.shape
-)
+log_marginal_likelihood = np.reshape(log_marginal_likelihood, noise_level_grid.shape)
 
 # %%
 vmin, vmax = (-log_marginal_likelihood).min(), 50
-level = np.around(np.logspace(np.log10(vmin), np.log10(vmax), num=50), decimals=1)
+level = np.around(np.logspace(np.log10(vmin), np.log10(vmax), num=20), decimals=1)
 plt.contour(
     length_scale_grid,
     noise_level_grid,

@@ -184,8 +185,43 @@ def target_generator(X, add_noise=False):
 plt.show()
 
 # %%
-# We see that there are two local minima that correspond to the combination
-# of hyperparameters previously found. Depending on the initial values for the
-# hyperparameters, the gradient-based optimization might converge whether or
-# not to the best model. It is thus important to repeat the optimization
-# several times for different initializations.
+#
+# We see that there are two local minima that correspond to the combination of
+# hyperparameters previously found. Depending on the initial values for the
+# hyperparameters, the gradient-based optimization might or might not
+# converge to the best model. It is thus important to repeat the optimization
+# several times for different initializations. This can be done by setting the
+# `n_restarts_optimizer` parameter of the
+# :class:`~sklearn.gaussian_process.GaussianProcessRegressor` class.
+#
+# Let's try again to fit our model with the bad initial values but this time
+# with 10 random restarts.
+
+kernel = 1.0 * RBF(length_scale=1e1, length_scale_bounds=(1e-2, 1e3)) + WhiteKernel(
+    noise_level=1, noise_level_bounds=(1e-10, 1e1)
+)
+gpr = GaussianProcessRegressor(
+    kernel=kernel, alpha=0.0, n_restarts_optimizer=10, random_state=0
+)
+gpr.fit(X_train, y_train)
+y_mean, y_std = gpr.predict(X, return_std=True)
+
+# %%
+plt.plot(X, y, label="Expected signal")
+plt.scatter(x=X_train[:, 0], y=y_train, color="black", alpha=0.4, label="Observations")
+plt.errorbar(X, y_mean, y_std, label="Posterior mean ± std")
+plt.legend()
+plt.xlabel("X")
+plt.ylabel("y")
+_ = plt.title(
+    (
+        f"Initial: {kernel}\nOptimum: {gpr.kernel_}\nLog-Marginal-Likelihood: "
+        f"{gpr.log_marginal_likelihood(gpr.kernel_.theta)}"
+    ),
+    fontsize=8,
+)
+
+# %%
+#
+# As we hoped, random restarts allow the optimization to find the best set
+# of hyperparameters despite the bad initial values.
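To see the restart mechanism in isolation, here is a self-contained sketch under stated assumptions: the synthetic data stands in for the example's `target_generator` (assumed to be roughly `0.5 + sin(3 * X)` plus Gaussian noise), and the kernel reuses the deliberately bad initial values from above.

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

rng = np.random.default_rng(0)
X_train = rng.uniform(0, 5, size=(20, 1))
y_train = 0.5 + np.sin(3 * X_train.ravel()) + rng.normal(scale=0.3, size=20)

# Deliberately bad initial hyperparameters, as in the example above
kernel = 1.0 * RBF(length_scale=1e1, length_scale_bounds=(1e-2, 1e3)) + WhiteKernel(
    noise_level=1, noise_level_bounds=(1e-10, 1e1)
)
gpr = GaussianProcessRegressor(
    kernel=kernel, alpha=0.0, n_restarts_optimizer=10, random_state=0
).fit(X_train, y_train)

# theta stores log-transformed hyperparameters; exponentiate for natural scale
print(gpr.kernel_)
print(np.exp(gpr.kernel_.theta))
print(gpr.log_marginal_likelihood(gpr.kernel_.theta))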
