Commit 4685fc7

grid search (#154)
* grid search draft
* hyperparam search for linear estimators
1 parent 2e5f88f commit 4685fc7

7 files changed: +649 -11 lines

src/linear/elastic_net.rs

Lines changed: 138 additions & 0 deletions
@@ -135,6 +135,121 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
     }
 }

+/// ElasticNet grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct ElasticNetSearchParameters<T: RealNumber> {
+    /// Regularization parameter.
+    pub alpha: Vec<T>,
+    /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
+    /// For l1_ratio = 0 the penalty is an L2 penalty.
+    /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
+    pub l1_ratio: Vec<T>,
+    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+    /// The tolerance for the optimization
+    pub tol: Vec<T>,
+    /// The maximum number of iterations
+    pub max_iter: Vec<usize>,
+}
+
+/// ElasticNet grid search iterator
+pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
+    elastic_net_search_parameters: ElasticNetSearchParameters<T>,
+    current_alpha: usize,
+    current_l1_ratio: usize,
+    current_normalize: usize,
+    current_tol: usize,
+    current_max_iter: usize,
+}
+
+impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
+    type Item = ElasticNetParameters<T>;
+    type IntoIter = ElasticNetSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        ElasticNetSearchParametersIterator {
+            elastic_net_search_parameters: self,
+            current_alpha: 0,
+            current_l1_ratio: 0,
+            current_normalize: 0,
+            current_tol: 0,
+            current_max_iter: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
+    type Item = ElasticNetParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // All indices sit past the end once the grid is exhausted.
+        if self.current_alpha == self.elastic_net_search_parameters.alpha.len()
+            && self.current_l1_ratio == self.elastic_net_search_parameters.l1_ratio.len()
+            && self.current_normalize == self.elastic_net_search_parameters.normalize.len()
+            && self.current_tol == self.elastic_net_search_parameters.tol.len()
+            && self.current_max_iter == self.elastic_net_search_parameters.max_iter.len()
+        {
+            return None;
+        }
+
+        let next = ElasticNetParameters {
+            alpha: self.elastic_net_search_parameters.alpha[self.current_alpha],
+            l1_ratio: self.elastic_net_search_parameters.l1_ratio[self.current_l1_ratio],
+            normalize: self.elastic_net_search_parameters.normalize[self.current_normalize],
+            tol: self.elastic_net_search_parameters.tol[self.current_tol],
+            max_iter: self.elastic_net_search_parameters.max_iter[self.current_max_iter],
+        };
+
+        // Advance the indices like an odometer: alpha varies fastest.
+        if self.current_alpha + 1 < self.elastic_net_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_l1_ratio + 1 < self.elastic_net_search_parameters.l1_ratio.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio += 1;
+        } else if self.current_normalize + 1 < self.elastic_net_search_parameters.normalize.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize += 1;
+        } else if self.current_tol + 1 < self.elastic_net_search_parameters.tol.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol += 1;
+        } else if self.current_max_iter + 1 < self.elastic_net_search_parameters.max_iter.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol = 0;
+            self.current_max_iter += 1;
+        } else {
+            // The last combination was just returned; move every index past the end.
+            self.current_alpha += 1;
+            self.current_l1_ratio += 1;
+            self.current_normalize += 1;
+            self.current_tol += 1;
+            self.current_max_iter += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = ElasticNetParameters::default();
+
+        ElasticNetSearchParameters {
+            alpha: vec![default_params.alpha],
+            l1_ratio: vec![default_params.l1_ratio],
+            normalize: vec![default_params.normalize],
+            tol: vec![default_params.tol],
+            max_iter: vec![default_params.max_iter],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
     fn eq(&self, other: &Self) -> bool {
         self.coefficients == other.coefficients
@@ -291,6 +406,29 @@ mod tests {
     use crate::linalg::naive::dense_matrix::*;
     use crate::metrics::mean_absolute_error;

+    #[test]
+    fn search_parameters() {
+        let parameters = ElasticNetSearchParameters {
+            alpha: vec![0., 1.],
+            max_iter: vec![10, 100],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 100);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 100);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn elasticnet_longley() {
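Taken together, the new types make a grid search a plain `for` loop over parameter sets. Below is a minimal sketch, not part of this commit: it assumes the `smartcore` crate root for the `use` paths, a pre-built `DenseMatrix` `x` with targets `y`, and it scores on the training data only to keep the loop short (in practice you would score on a held-out split or under cross-validation).

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::elastic_net::{
    ElasticNet, ElasticNetParameters, ElasticNetSearchParameters,
};
use smartcore::metrics::mean_absolute_error;

fn grid_search(x: &DenseMatrix<f64>, y: &Vec<f64>) -> ElasticNetParameters<f64> {
    let grid = ElasticNetSearchParameters {
        alpha: vec![0.1, 1.0, 10.0],
        l1_ratio: vec![0.2, 0.5, 0.8],
        ..Default::default() // single default values for normalize, tol, max_iter
    };

    let mut best: Option<(f64, ElasticNetParameters<f64>)> = None;
    for params in grid {
        // Each item is an ordinary ElasticNetParameters drawn from the grid.
        let y_hat = ElasticNet::fit(x, y, params.clone())
            .and_then(|m| m.predict(x))
            .unwrap();
        let mae = mean_absolute_error(y, &y_hat);
        if best.as_ref().map_or(true, |(b, _)| mae < *b) {
            best = Some((mae, params));
        }
    }
    best.unwrap().1
}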

src/linear/lasso.rs

Lines changed: 122 additions & 0 deletions
@@ -112,6 +112,105 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
     }
 }

+/// Lasso grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct LassoSearchParameters<T: RealNumber> {
+    /// Controls the strength of the penalty to the loss function.
+    pub alpha: Vec<T>,
+    /// If true, the regressors X will be normalized before regression
+    /// by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+    /// The tolerance for the optimization
+    pub tol: Vec<T>,
+    /// The maximum number of iterations
+    pub max_iter: Vec<usize>,
+}
+
+/// Lasso grid search iterator
+pub struct LassoSearchParametersIterator<T: RealNumber> {
+    lasso_regression_search_parameters: LassoSearchParameters<T>,
+    current_alpha: usize,
+    current_normalize: usize,
+    current_tol: usize,
+    current_max_iter: usize,
+}
+
+impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
+    type Item = LassoParameters<T>;
+    type IntoIter = LassoSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LassoSearchParametersIterator {
+            lasso_regression_search_parameters: self,
+            current_alpha: 0,
+            current_normalize: 0,
+            current_tol: 0,
+            current_max_iter: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
+    type Item = LassoParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // All indices sit past the end once the grid is exhausted.
+        if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
+            && self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
+            && self.current_tol == self.lasso_regression_search_parameters.tol.len()
+            && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
+        {
+            return None;
+        }
+
+        let next = LassoParameters {
+            alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
+            normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
+            tol: self.lasso_regression_search_parameters.tol[self.current_tol],
+            max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
+        };
+
+        // Advance the indices like an odometer: alpha varies fastest.
+        if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_normalize + 1 < self.lasso_regression_search_parameters.normalize.len() {
+            self.current_alpha = 0;
+            self.current_normalize += 1;
+        } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
+            self.current_alpha = 0;
+            self.current_normalize = 0;
+            self.current_tol += 1;
+        } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len() {
+            self.current_alpha = 0;
+            self.current_normalize = 0;
+            self.current_tol = 0;
+            self.current_max_iter += 1;
+        } else {
+            // The last combination was just returned; move every index past the end.
+            self.current_alpha += 1;
+            self.current_normalize += 1;
+            self.current_tol += 1;
+            self.current_max_iter += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for LassoSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = LassoParameters::default();
+
+        LassoSearchParameters {
+            alpha: vec![default_params.alpha],
+            normalize: vec![default_params.normalize],
+            tol: vec![default_params.tol],
+            max_iter: vec![default_params.max_iter],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
     /// Fits Lasso regression to your data.
     /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -226,6 +325,29 @@ mod tests {
    use crate::linalg::naive::dense_matrix::*;
    use crate::metrics::mean_absolute_error;

+    #[test]
+    fn search_parameters() {
+        let parameters = LassoSearchParameters {
+            alpha: vec![0., 1.],
+            max_iter: vec![10, 100],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 100);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 100);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn lasso_fit_predict() {
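The iterator advances like an odometer: `alpha` is the fastest-moving wheel and `max_iter` the slowest, so the number of yielded items is the product of the four vector lengths. A small sketch, not part of this commit (again assuming the `smartcore` crate root), makes the ordering visible:

use smartcore::linear::lasso::LassoSearchParameters;

fn main() {
    let grid = LassoSearchParameters {
        alpha: vec![0.1, 1.0],
        tol: vec![1e-4, 1e-6],
        ..Default::default() // one default value each for normalize and max_iter
    };
    // A 2 (alpha) x 1 (normalize) x 2 (tol) x 1 (max_iter) grid yields
    // exactly 4 parameter sets, with alpha cycling before tol:
    // (0.1, 1e-4), (1.0, 1e-4), (0.1, 1e-6), (1.0, 1e-6)
    for (i, params) in grid.into_iter().enumerate() {
        println!("{}: alpha={} tol={}", i, params.alpha, params.tol);
    }
}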

src/linear/linear_regression.rs

Lines changed: 69 additions & 1 deletion
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
 use crate::math::num::RealNumber;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
 pub enum LinearRegressionSolverName {
     /// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
     }
 }

+/// Linear Regression grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct LinearRegressionSearchParameters {
+    /// Solver to use for estimation of regression coefficients.
+    pub solver: Vec<LinearRegressionSolverName>,
+}
+
+/// Linear Regression grid search iterator
+pub struct LinearRegressionSearchParametersIterator {
+    linear_regression_search_parameters: LinearRegressionSearchParameters,
+    current_solver: usize,
+}
+
+impl IntoIterator for LinearRegressionSearchParameters {
+    type Item = LinearRegressionParameters;
+    type IntoIter = LinearRegressionSearchParametersIterator;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LinearRegressionSearchParametersIterator {
+            linear_regression_search_parameters: self,
+            current_solver: 0,
+        }
+    }
+}
+
+impl Iterator for LinearRegressionSearchParametersIterator {
+    type Item = LinearRegressionParameters;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_solver == self.linear_regression_search_parameters.solver.len() {
+            return None;
+        }
+
+        let next = LinearRegressionParameters {
+            solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
+        };
+
+        self.current_solver += 1;
+
+        Some(next)
+    }
+}
+
+impl Default for LinearRegressionSearchParameters {
+    fn default() -> Self {
+        let default_params = LinearRegressionParameters::default();
+
+        LinearRegressionSearchParameters {
+            solver: vec![default_params.solver],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
     fn eq(&self, other: &Self) -> bool {
         self.coefficients == other.coefficients
@@ -200,6 +254,20 @@ mod tests {
     use super::*;
     use crate::linalg::naive::dense_matrix::*;

+    #[test]
+    fn search_parameters() {
+        let parameters = LinearRegressionSearchParameters {
+            solver: vec![
+                LinearRegressionSolverName::QR,
+                LinearRegressionSolverName::SVD,
+            ],
+        };
+        let mut iter = parameters.into_iter();
+        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
+        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn ols_fit_predict() {
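With only one hyperparameter, the linear-regression grid degenerates to a list of solvers, but the same `IntoIterator` pattern applies, and the new `Eq`/`PartialEq` derives on the solver enum are what let the test above assert on it directly. A minimal sketch, not part of this commit, assuming the `smartcore` crate root and small made-up data:

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::linear_regression::{
    LinearRegression, LinearRegressionSearchParameters, LinearRegressionSolverName,
};

fn main() {
    // Toy data following y = x1 + 2 * x2 + 3
    let x = DenseMatrix::from_2d_array(&[&[1., 1.], &[1., 2.], &[2., 2.], &[2., 3.]]);
    let y = vec![6., 8., 9., 11.];

    let grid = LinearRegressionSearchParameters {
        solver: vec![
            LinearRegressionSolverName::QR,
            LinearRegressionSolverName::SVD,
        ],
    };
    for params in grid {
        let solver = params.solver.clone(); // keep a copy; fit consumes params
        let y_hat = LinearRegression::fit(&x, &y, params)
            .and_then(|m| m.predict(&x))
            .unwrap();
        println!("{:?}: {:?}", solver, y_hat);
    }
}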
