Skip to content

Commit 55e1158

Browse files
montanalowmorenol
authored andcommitted
Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
1 parent cfa824d commit 55e1158

File tree

18 files changed

+1713
-25
lines changed

18 files changed

+1713
-25
lines changed

src/cluster/dbscan.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
109109
}
110110
}
111111

112+
/// DBSCAN grid search parameters
113+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114+
#[derive(Debug, Clone)]
115+
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
116+
/// a function that defines a distance between each pair of point in training data.
117+
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
118+
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
119+
pub distance: Vec<D>,
120+
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
121+
pub min_samples: Vec<usize>,
122+
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
123+
pub eps: Vec<T>,
124+
/// KNN algorithm to use.
125+
pub algorithm: Vec<KNNAlgorithmName>,
126+
}
127+
128+
/// DBSCAN grid search iterator
129+
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
130+
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
131+
current_distance: usize,
132+
current_min_samples: usize,
133+
current_eps: usize,
134+
current_algorithm: usize,
135+
}
136+
137+
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
138+
type Item = DBSCANParameters<T, D>;
139+
type IntoIter = DBSCANSearchParametersIterator<T, D>;
140+
141+
fn into_iter(self) -> Self::IntoIter {
142+
DBSCANSearchParametersIterator {
143+
dbscan_search_parameters: self,
144+
current_distance: 0,
145+
current_min_samples: 0,
146+
current_eps: 0,
147+
current_algorithm: 0,
148+
}
149+
}
150+
}
151+
152+
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
153+
type Item = DBSCANParameters<T, D>;
154+
155+
fn next(&mut self) -> Option<Self::Item> {
156+
if self.current_distance == self.dbscan_search_parameters.distance.len()
157+
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
158+
&& self.current_eps == self.dbscan_search_parameters.eps.len()
159+
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
160+
{
161+
return None;
162+
}
163+
164+
let next = DBSCANParameters {
165+
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
166+
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
167+
eps: self.dbscan_search_parameters.eps[self.current_eps],
168+
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
169+
};
170+
171+
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
172+
self.current_distance += 1;
173+
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
174+
self.current_distance = 0;
175+
self.current_min_samples += 1;
176+
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
177+
self.current_distance = 0;
178+
self.current_min_samples = 0;
179+
self.current_eps += 1;
180+
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
181+
self.current_distance = 0;
182+
self.current_min_samples = 0;
183+
self.current_eps = 0;
184+
self.current_algorithm += 1;
185+
} else {
186+
self.current_distance += 1;
187+
self.current_min_samples += 1;
188+
self.current_eps += 1;
189+
self.current_algorithm += 1;
190+
}
191+
192+
Some(next)
193+
}
194+
}
195+
196+
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
197+
fn default() -> Self {
198+
let default_params = DBSCANParameters::default();
199+
200+
DBSCANSearchParameters {
201+
distance: vec![default_params.distance],
202+
min_samples: vec![default_params.min_samples],
203+
eps: vec![default_params.eps],
204+
algorithm: vec![default_params.algorithm],
205+
}
206+
}
207+
}
208+
112209
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
113210
fn eq(&self, other: &Self) -> bool {
114211
self.cluster_labels.len() == other.cluster_labels.len()
@@ -268,6 +365,29 @@ mod tests {
268365
#[cfg(feature = "serde")]
269366
use crate::math::distance::euclidian::Euclidian;
270367

368+
#[test]
369+
fn search_parameters() {
370+
let parameters = DBSCANSearchParameters {
371+
min_samples: vec![10, 100],
372+
eps: vec![1., 2.],
373+
..Default::default()
374+
};
375+
let mut iter = parameters.into_iter();
376+
let next = iter.next().unwrap();
377+
assert_eq!(next.min_samples, 10);
378+
assert_eq!(next.eps, 1.);
379+
let next = iter.next().unwrap();
380+
assert_eq!(next.min_samples, 100);
381+
assert_eq!(next.eps, 1.);
382+
let next = iter.next().unwrap();
383+
assert_eq!(next.min_samples, 10);
384+
assert_eq!(next.eps, 2.);
385+
let next = iter.next().unwrap();
386+
assert_eq!(next.min_samples, 100);
387+
assert_eq!(next.eps, 2.);
388+
assert!(iter.next().is_none());
389+
}
390+
271391
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
272392
#[test]
273393
fn fit_predict_dbscan() {

src/cluster/kmeans.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
132132
}
133133
}
134134

135+
/// KMeans grid search parameters
136+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
137+
#[derive(Debug, Clone)]
138+
pub struct KMeansSearchParameters {
139+
/// Number of clusters.
140+
pub k: Vec<usize>,
141+
/// Maximum number of iterations of the k-means algorithm for a single run.
142+
pub max_iter: Vec<usize>,
143+
}
144+
145+
/// KMeans grid search iterator
146+
pub struct KMeansSearchParametersIterator {
147+
kmeans_search_parameters: KMeansSearchParameters,
148+
current_k: usize,
149+
current_max_iter: usize,
150+
}
151+
152+
impl IntoIterator for KMeansSearchParameters {
153+
type Item = KMeansParameters;
154+
type IntoIter = KMeansSearchParametersIterator;
155+
156+
fn into_iter(self) -> Self::IntoIter {
157+
KMeansSearchParametersIterator {
158+
kmeans_search_parameters: self,
159+
current_k: 0,
160+
current_max_iter: 0,
161+
}
162+
}
163+
}
164+
165+
impl Iterator for KMeansSearchParametersIterator {
166+
type Item = KMeansParameters;
167+
168+
fn next(&mut self) -> Option<Self::Item> {
169+
if self.current_k == self.kmeans_search_parameters.k.len()
170+
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
171+
{
172+
return None;
173+
}
174+
175+
let next = KMeansParameters {
176+
k: self.kmeans_search_parameters.k[self.current_k],
177+
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
178+
};
179+
180+
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
181+
self.current_k += 1;
182+
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
183+
self.current_k = 0;
184+
self.current_max_iter += 1;
185+
} else {
186+
self.current_k += 1;
187+
self.current_max_iter += 1;
188+
}
189+
190+
Some(next)
191+
}
192+
}
193+
194+
impl Default for KMeansSearchParameters {
195+
fn default() -> Self {
196+
let default_params = KMeansParameters::default();
197+
198+
KMeansSearchParameters {
199+
k: vec![default_params.k],
200+
max_iter: vec![default_params.max_iter],
201+
}
202+
}
203+
}
204+
135205
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
136206
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
137207
KMeans::fit(x, parameters)
@@ -313,6 +383,29 @@ mod tests {
313383
);
314384
}
315385

386+
#[test]
387+
fn search_parameters() {
388+
let parameters = KMeansSearchParameters {
389+
k: vec![2, 4],
390+
max_iter: vec![10, 100],
391+
..Default::default()
392+
};
393+
let mut iter = parameters.into_iter();
394+
let next = iter.next().unwrap();
395+
assert_eq!(next.k, 2);
396+
assert_eq!(next.max_iter, 10);
397+
let next = iter.next().unwrap();
398+
assert_eq!(next.k, 4);
399+
assert_eq!(next.max_iter, 10);
400+
let next = iter.next().unwrap();
401+
assert_eq!(next.k, 2);
402+
assert_eq!(next.max_iter, 100);
403+
let next = iter.next().unwrap();
404+
assert_eq!(next.k, 4);
405+
assert_eq!(next.max_iter, 100);
406+
assert!(iter.next().is_none());
407+
}
408+
316409
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
317410
#[test]
318411
fn fit_predict_iris() {

src/decomposition/pca.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,81 @@ impl Default for PCAParameters {
116116
}
117117
}
118118

119+
/// PCA grid search parameters
120+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
121+
#[derive(Debug, Clone)]
122+
pub struct PCASearchParameters {
123+
/// Number of components to keep.
124+
pub n_components: Vec<usize>,
125+
/// By default, covariance matrix is used to compute principal components.
126+
/// Enable this flag if you want to use correlation matrix instead.
127+
pub use_correlation_matrix: Vec<bool>,
128+
}
129+
130+
/// PCA grid search iterator
131+
pub struct PCASearchParametersIterator {
132+
pca_search_parameters: PCASearchParameters,
133+
current_k: usize,
134+
current_use_correlation_matrix: usize,
135+
}
136+
137+
impl IntoIterator for PCASearchParameters {
138+
type Item = PCAParameters;
139+
type IntoIter = PCASearchParametersIterator;
140+
141+
fn into_iter(self) -> Self::IntoIter {
142+
PCASearchParametersIterator {
143+
pca_search_parameters: self,
144+
current_k: 0,
145+
current_use_correlation_matrix: 0,
146+
}
147+
}
148+
}
149+
150+
impl Iterator for PCASearchParametersIterator {
151+
type Item = PCAParameters;
152+
153+
fn next(&mut self) -> Option<Self::Item> {
154+
if self.current_k == self.pca_search_parameters.n_components.len()
155+
&& self.current_use_correlation_matrix
156+
== self.pca_search_parameters.use_correlation_matrix.len()
157+
{
158+
return None;
159+
}
160+
161+
let next = PCAParameters {
162+
n_components: self.pca_search_parameters.n_components[self.current_k],
163+
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
164+
[self.current_use_correlation_matrix],
165+
};
166+
167+
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
168+
self.current_k += 1;
169+
} else if self.current_use_correlation_matrix + 1
170+
< self.pca_search_parameters.use_correlation_matrix.len()
171+
{
172+
self.current_k = 0;
173+
self.current_use_correlation_matrix += 1;
174+
} else {
175+
self.current_k += 1;
176+
self.current_use_correlation_matrix += 1;
177+
}
178+
179+
Some(next)
180+
}
181+
}
182+
183+
impl Default for PCASearchParameters {
184+
fn default() -> Self {
185+
let default_params = PCAParameters::default();
186+
187+
PCASearchParameters {
188+
n_components: vec![default_params.n_components],
189+
use_correlation_matrix: vec![default_params.use_correlation_matrix],
190+
}
191+
}
192+
}
193+
119194
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
120195
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
121196
PCA::fit(x, parameters)
@@ -271,6 +346,29 @@ mod tests {
271346
use super::*;
272347
use crate::linalg::naive::dense_matrix::*;
273348

349+
#[test]
350+
fn search_parameters() {
351+
let parameters = PCASearchParameters {
352+
n_components: vec![2, 4],
353+
use_correlation_matrix: vec![true, false],
354+
..Default::default()
355+
};
356+
let mut iter = parameters.into_iter();
357+
let next = iter.next().unwrap();
358+
assert_eq!(next.n_components, 2);
359+
assert_eq!(next.use_correlation_matrix, true);
360+
let next = iter.next().unwrap();
361+
assert_eq!(next.n_components, 4);
362+
assert_eq!(next.use_correlation_matrix, true);
363+
let next = iter.next().unwrap();
364+
assert_eq!(next.n_components, 2);
365+
assert_eq!(next.use_correlation_matrix, false);
366+
let next = iter.next().unwrap();
367+
assert_eq!(next.n_components, 4);
368+
assert_eq!(next.use_correlation_matrix, false);
369+
assert!(iter.next().is_none());
370+
}
371+
274372
fn us_arrests_data() -> DenseMatrix<f64> {
275373
DenseMatrix::from_2d_array(&[
276374
&[13.2, 236.0, 58.0, 21.2],

0 commit comments

Comments
 (0)