Skip to content

Commit aa2131f

Browse files
RafaAyGarlestevejeremiedbb
authored
EFF Make GaussianProcessRegressor.predict faster when return_std and return_cov are false (scikit-learn#31431)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 36ef203 commit aa2131f

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- make :class:`GaussianProcessRegressor.predict` faster when `return_cov` and
2+
`return_std` are both `False`.
3+
By :user:`Rafael Ayllón Gavilán <RafaAyGar>`.

sklearn/gaussian_process/_gpr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,9 @@ def predict(self, X, return_std=False, return_cov=False):
450450
if y_mean.ndim > 1 and y_mean.shape[1] == 1:
451451
y_mean = np.squeeze(y_mean, axis=1)
452452

453+
if not return_cov and not return_std:
454+
return y_mean
455+
453456
# Alg 2.1, page 19, line 5 -> v = L \ K(X_test, X_train)^T
454457
V = solve_triangular(
455458
self.L_, K_trans.T, lower=GPR_CHOLESKY_LOWER, check_finite=False
@@ -467,7 +470,7 @@ def predict(self, X, return_std=False, return_cov=False):
467470
y_cov = np.squeeze(y_cov, axis=2)
468471

469472
return y_mean, y_cov
470-
elif return_std:
473+
else: # return_std
471474
# Compute variance of predictive distribution
472475
# Use einsum to avoid explicitly forming the large matrix
473476
# V^T @ V just to extract its diagonal afterward.
@@ -492,8 +495,6 @@ def predict(self, X, return_std=False, return_cov=False):
492495
y_var = np.squeeze(y_var, axis=1)
493496

494497
return y_mean, np.sqrt(y_var)
495-
else:
496-
return y_mean
497498

498499
def sample_y(self, X, n_samples=1, random_state=0):
499500
"""Draw samples from Gaussian process and evaluate at X.

sklearn/gaussian_process/tests/test_gpr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,3 +847,15 @@ def test_gpr_predict_input_not_modified():
847847
_, _ = gpr.predict(X2, return_std=True)
848848

849849
assert_allclose(X2, X2_copy)
850+
851+
852+
@pytest.mark.parametrize("kernel", kernels)
853+
def test_gpr_predict_no_cov_no_std_return(kernel):
854+
"""
855+
Check that only y_mean is returned when return_cov=False and
856+
return_std=False.
857+
"""
858+
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
859+
y_pred = gpr.predict(X, return_cov=False, return_std=False)
860+
861+
assert_allclose(y_pred, y)

0 commit comments

Comments
 (0)