Skip to content

Commit 49fb295

Browse files
MechCoderjnothman
authored andcommitted
[MRG+2] Fix repr on isotropic kernels when a 1-D length scale is given (scikit-learn#7259)
1 parent 0f2a00f commit 49fb295

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

sklearn/gaussian_process/kernels.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@ def __repr__(self):
12041204
self.length_scale)))
12051205
else: # isotropic
12061206
return "{0}(length_scale={1:.3g})".format(
1207-
self.__class__.__name__, self.length_scale)
1207+
self.__class__.__name__, np.ravel(self.length_scale)[0])
12081208

12091209

12101210
class Matern(RBF):
@@ -1348,9 +1348,10 @@ def __repr__(self):
13481348
self.__class__.__name__,
13491349
", ".join(map("{0:.3g}".format, self.length_scale)),
13501350
self.nu)
1351-
else: # isotropic
1351+
else:
13521352
return "{0}(length_scale={1:.3g}, nu={2:.3g})".format(
1353-
self.__class__.__name__, self.length_scale, self.nu)
1353+
self.__class__.__name__, np.ravel(self.length_scale)[0],
1354+
self.nu)
13541355

13551356

13561357
class RationalQuadratic(StationaryKernelMixin, NormalizedKernelMixin, Kernel):

sklearn/gaussian_process/tests/test_kernels.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
4.0 * Matern(length_scale=[0.5, 0.5], nu=2.5),
4242
RationalQuadratic(length_scale=0.5, alpha=1.5),
4343
ExpSineSquared(length_scale=0.5, periodicity=1.5),
44-
DotProduct(sigma_0=2.0), DotProduct(sigma_0=2.0) ** 2]
44+
DotProduct(sigma_0=2.0), DotProduct(sigma_0=2.0) ** 2,
45+
RBF(length_scale=[2.0]), Matern(length_scale=[2.0])]
4546
for metric in PAIRWISE_KERNEL_FUNCTIONS:
4647
if metric in ["additive_chi2", "chi2"]:
4748
continue
@@ -304,3 +305,10 @@ def test_set_get_params():
304305
kernel.set_params(**{hyperparameter.name: value})
305306
assert_almost_equal(np.exp(kernel.theta[index]), value)
306307
index += 1
308+
309+
310+
def test_repr_kernels():
311+
"""Smoke-test for repr in kernels."""
312+
313+
for kernel in kernels:
314+
repr(kernel)

0 commit comments

Comments
 (0)