Skip to content

Commit 764309e

Browse files
authored
make default params available to serde (#167)
* add seed param to search params * make default params available to serde * lints * create defaults for enums * lint
1 parent 403d3f2 commit 764309e

22 files changed

+175
-18
lines changed

src/algorithm/neighbour/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ pub enum KNNAlgorithmName {
5959
CoverTree,
6060
}
6161

62+
impl Default for KNNAlgorithmName {
63+
fn default() -> Self {
64+
KNNAlgorithmName::CoverTree
65+
}
66+
}
67+
6268
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6369
#[derive(Debug)]
6470
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {

src/cluster/dbscan.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,22 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
6565
eps: T,
6666
}
6767

68+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6869
#[derive(Debug, Clone)]
6970
/// DBSCAN clustering algorithm parameters
7071
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
72+
#[cfg_attr(feature = "serde", serde(default))]
7173
/// a function that defines a distance between each pair of points in training data.
7274
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
7375
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
7476
pub distance: D,
77+
#[cfg_attr(feature = "serde", serde(default))]
7578
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
7679
pub min_samples: usize,
80+
#[cfg_attr(feature = "serde", serde(default))]
7781
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
7882
pub eps: T,
83+
#[cfg_attr(feature = "serde", serde(default))]
7984
/// KNN algorithm to use.
8085
pub algorithm: KNNAlgorithmName,
8186
}
@@ -113,14 +118,18 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
113118
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114119
#[derive(Debug, Clone)]
115120
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
121+
#[cfg_attr(feature = "serde", serde(default))]
116122
/// a function that defines a distance between each pair of points in training data.
117123
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
118124
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
119125
pub distance: Vec<D>,
126+
#[cfg_attr(feature = "serde", serde(default))]
120127
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
121128
pub min_samples: Vec<usize>,
129+
#[cfg_attr(feature = "serde", serde(default))]
122130
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
123131
pub eps: Vec<T>,
132+
#[cfg_attr(feature = "serde", serde(default))]
124133
/// KNN algorithm to use.
125134
pub algorithm: Vec<KNNAlgorithmName>,
126135
}
@@ -221,7 +230,7 @@ impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
221230
distance: Distances::euclidian(),
222231
min_samples: 5,
223232
eps: T::half(),
224-
algorithm: KNNAlgorithmName::CoverTree,
233+
algorithm: KNNAlgorithmName::default(),
225234
}
226235
}
227236
}

src/cluster/kmeans.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,17 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
102102
}
103103
}
104104

105+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
105106
#[derive(Debug, Clone)]
106107
/// K-Means clustering algorithm parameters
107108
pub struct KMeansParameters {
109+
#[cfg_attr(feature = "serde", serde(default))]
108110
/// Number of clusters.
109111
pub k: usize,
112+
#[cfg_attr(feature = "serde", serde(default))]
110113
/// Maximum number of iterations of the k-means algorithm for a single run.
111114
pub max_iter: usize,
115+
#[cfg_attr(feature = "serde", serde(default))]
112116
/// Determines random number generation for centroid initialization.
113117
/// Use an int to make the randomness deterministic
114118
pub seed: Option<u64>,
@@ -141,10 +145,13 @@ impl Default for KMeansParameters {
141145
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
142146
#[derive(Debug, Clone)]
143147
pub struct KMeansSearchParameters {
148+
#[cfg_attr(feature = "serde", serde(default))]
144149
/// Number of clusters.
145150
pub k: Vec<usize>,
151+
#[cfg_attr(feature = "serde", serde(default))]
146152
/// Maximum number of iterations of the k-means algorithm for a single run.
147153
pub max_iter: Vec<usize>,
154+
#[cfg_attr(feature = "serde", serde(default))]
148155
/// Determines random number generation for centroid initialization.
149156
/// Use an int to make the randomness deterministic
150157
pub seed: Vec<Option<u64>>,

src/decomposition/pca.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
8383
}
8484
}
8585

86+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8687
#[derive(Debug, Clone)]
8788
/// PCA parameters
8889
pub struct PCAParameters {
90+
#[cfg_attr(feature = "serde", serde(default))]
8991
/// Number of components to keep.
9092
pub n_components: usize,
93+
#[cfg_attr(feature = "serde", serde(default))]
9194
/// By default, covariance matrix is used to compute principal components.
9295
/// Enable this flag if you want to use correlation matrix instead.
9396
pub use_correlation_matrix: bool,
@@ -120,8 +123,10 @@ impl Default for PCAParameters {
120123
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
121124
#[derive(Debug, Clone)]
122125
pub struct PCASearchParameters {
126+
#[cfg_attr(feature = "serde", serde(default))]
123127
/// Number of components to keep.
124128
pub n_components: Vec<usize>,
129+
#[cfg_attr(feature = "serde", serde(default))]
125130
/// By default, covariance matrix is used to compute principal components.
126131
/// Enable this flag if you want to use correlation matrix instead.
127132
pub use_correlation_matrix: Vec<bool>,

src/decomposition/svd.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
6969
}
7070
}
7171

72+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7273
#[derive(Debug, Clone)]
7374
/// SVD parameters
7475
pub struct SVDParameters {
76+
#[cfg_attr(feature = "serde", serde(default))]
7577
/// Number of components to keep.
7678
pub n_components: usize,
7779
}
@@ -94,6 +96,7 @@ impl SVDParameters {
9496
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9597
#[derive(Debug, Clone)]
9698
pub struct SVDSearchParameters {
99+
#[cfg_attr(feature = "serde", serde(default))]
97100
/// Number of components to keep.
98101
pub n_components: Vec<usize>,
99102
}

src/ensemble/random_forest_classifier.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,28 @@ use crate::tree::decision_tree_classifier::{
6767
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6868
#[derive(Debug, Clone)]
6969
pub struct RandomForestClassifierParameters {
70+
#[cfg_attr(feature = "serde", serde(default))]
7071
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
7172
pub criterion: SplitCriterion,
73+
#[cfg_attr(feature = "serde", serde(default))]
7274
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
7375
pub max_depth: Option<u16>,
76+
#[cfg_attr(feature = "serde", serde(default))]
7477
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
7578
pub min_samples_leaf: usize,
79+
#[cfg_attr(feature = "serde", serde(default))]
7680
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
7781
pub min_samples_split: usize,
82+
#[cfg_attr(feature = "serde", serde(default))]
7883
/// The number of trees in the forest.
7984
pub n_trees: u16,
85+
#[cfg_attr(feature = "serde", serde(default))]
8086
/// Number of random sample of predictors to use as split candidates.
8187
pub m: Option<usize>,
88+
#[cfg_attr(feature = "serde", serde(default))]
8289
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
8390
pub keep_samples: bool,
91+
#[cfg_attr(feature = "serde", serde(default))]
8492
/// Seed used for bootstrap sampling and feature selection for each tree.
8593
pub seed: u64,
8694
}
@@ -198,20 +206,28 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestCla
198206
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
199207
#[derive(Debug, Clone)]
200208
pub struct RandomForestClassifierSearchParameters {
209+
#[cfg_attr(feature = "serde", serde(default))]
201210
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
202211
pub criterion: Vec<SplitCriterion>,
212+
#[cfg_attr(feature = "serde", serde(default))]
203213
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
204214
pub max_depth: Vec<Option<u16>>,
215+
#[cfg_attr(feature = "serde", serde(default))]
205216
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
206217
pub min_samples_leaf: Vec<usize>,
218+
#[cfg_attr(feature = "serde", serde(default))]
207219
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
208220
pub min_samples_split: Vec<usize>,
221+
#[cfg_attr(feature = "serde", serde(default))]
209222
/// The number of trees in the forest.
210223
pub n_trees: Vec<u16>,
224+
#[cfg_attr(feature = "serde", serde(default))]
211225
/// Number of random sample of predictors to use as split candidates.
212226
pub m: Vec<Option<usize>>,
227+
#[cfg_attr(feature = "serde", serde(default))]
213228
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
214229
pub keep_samples: Vec<bool>,
230+
#[cfg_attr(feature = "serde", serde(default))]
215231
/// Seed used for bootstrap sampling and feature selection for each tree.
216232
pub seed: Vec<u64>,
217233
}

src/ensemble/random_forest_regressor.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,25 @@ use crate::tree::decision_tree_regressor::{
6565
/// Parameters of the Random Forest Regressor
6666
/// Some parameters here are passed directly into base estimator.
6767
pub struct RandomForestRegressorParameters {
68+
#[cfg_attr(feature = "serde", serde(default))]
6869
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
6970
pub max_depth: Option<u16>,
71+
#[cfg_attr(feature = "serde", serde(default))]
7072
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
7173
pub min_samples_leaf: usize,
74+
#[cfg_attr(feature = "serde", serde(default))]
7275
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
7376
pub min_samples_split: usize,
77+
#[cfg_attr(feature = "serde", serde(default))]
7478
/// The number of trees in the forest.
7579
pub n_trees: usize,
80+
#[cfg_attr(feature = "serde", serde(default))]
7681
/// Number of random sample of predictors to use as split candidates.
7782
pub m: Option<usize>,
83+
#[cfg_attr(feature = "serde", serde(default))]
7884
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
7985
pub keep_samples: bool,
86+
#[cfg_attr(feature = "serde", serde(default))]
8087
/// Seed used for bootstrap sampling and feature selection for each tree.
8188
pub seed: u64,
8289
}
@@ -181,18 +188,25 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestReg
181188
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
182189
#[derive(Debug, Clone)]
183190
pub struct RandomForestRegressorSearchParameters {
191+
#[cfg_attr(feature = "serde", serde(default))]
184192
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
185193
pub max_depth: Vec<Option<u16>>,
194+
#[cfg_attr(feature = "serde", serde(default))]
186195
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
187196
pub min_samples_leaf: Vec<usize>,
197+
#[cfg_attr(feature = "serde", serde(default))]
188198
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
189199
pub min_samples_split: Vec<usize>,
200+
#[cfg_attr(feature = "serde", serde(default))]
190201
/// The number of trees in the forest.
191202
pub n_trees: Vec<usize>,
203+
#[cfg_attr(feature = "serde", serde(default))]
192204
/// Number of random sample of predictors to use as split candidates.
193205
pub m: Vec<Option<usize>>,
206+
#[cfg_attr(feature = "serde", serde(default))]
194207
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
195208
pub keep_samples: Vec<bool>,
209+
#[cfg_attr(feature = "serde", serde(default))]
196210
/// Seed used for bootstrap sampling and feature selection for each tree.
197211
pub seed: Vec<u64>,
198212
}

src/linear/elastic_net.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,21 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
7171
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7272
#[derive(Debug, Clone)]
7373
pub struct ElasticNetParameters<T: RealNumber> {
74+
#[cfg_attr(feature = "serde", serde(default))]
7475
/// Regularization parameter.
7576
pub alpha: T,
77+
#[cfg_attr(feature = "serde", serde(default))]
7678
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
7779
/// For l1_ratio = 0 the penalty is an L2 penalty.
7880
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
7981
pub l1_ratio: T,
82+
#[cfg_attr(feature = "serde", serde(default))]
8083
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
8184
pub normalize: bool,
85+
#[cfg_attr(feature = "serde", serde(default))]
8286
/// The tolerance for the optimization
8387
pub tol: T,
88+
#[cfg_attr(feature = "serde", serde(default))]
8489
/// The maximum number of iterations
8590
pub max_iter: usize,
8691
}
@@ -139,16 +144,21 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
139144
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
140145
#[derive(Debug, Clone)]
141146
pub struct ElasticNetSearchParameters<T: RealNumber> {
147+
#[cfg_attr(feature = "serde", serde(default))]
142148
/// Regularization parameter.
143149
pub alpha: Vec<T>,
150+
#[cfg_attr(feature = "serde", serde(default))]
144151
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
145152
/// For l1_ratio = 0 the penalty is an L2 penalty.
146153
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
147154
pub l1_ratio: Vec<T>,
155+
#[cfg_attr(feature = "serde", serde(default))]
148156
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
149157
pub normalize: Vec<bool>,
158+
#[cfg_attr(feature = "serde", serde(default))]
150159
/// The tolerance for the optimization
151160
pub tol: Vec<T>,
161+
#[cfg_attr(feature = "serde", serde(default))]
152162
/// The maximum number of iterations
153163
pub max_iter: Vec<usize>,
154164
}

src/linear/lasso.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ use crate::math::num::RealNumber;
3838
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
3939
#[derive(Debug, Clone)]
4040
pub struct LassoParameters<T: RealNumber> {
41+
#[cfg_attr(feature = "serde", serde(default))]
4142
/// Controls the strength of the penalty to the loss function.
4243
pub alpha: T,
44+
#[cfg_attr(feature = "serde", serde(default))]
4345
/// If true the regressors X will be normalized before regression
4446
/// by subtracting the mean and dividing by the standard deviation.
4547
pub normalize: bool,
48+
#[cfg_attr(feature = "serde", serde(default))]
4649
/// The tolerance for the optimization
4750
pub tol: T,
51+
#[cfg_attr(feature = "serde", serde(default))]
4852
/// The maximum number of iterations
4953
pub max_iter: usize,
5054
}
@@ -116,13 +120,17 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
116120
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
117121
#[derive(Debug, Clone)]
118122
pub struct LassoSearchParameters<T: RealNumber> {
123+
#[cfg_attr(feature = "serde", serde(default))]
119124
/// Controls the strength of the penalty to the loss function.
120125
pub alpha: Vec<T>,
126+
#[cfg_attr(feature = "serde", serde(default))]
121127
/// If true the regressors X will be normalized before regression
122128
/// by subtracting the mean and dividing by the standard deviation.
123129
pub normalize: Vec<bool>,
130+
#[cfg_attr(feature = "serde", serde(default))]
124131
/// The tolerance for the optimization
125132
pub tol: Vec<T>,
133+
#[cfg_attr(feature = "serde", serde(default))]
126134
/// The maximum number of iterations
127135
pub max_iter: Vec<usize>,
128136
}

src/linear/linear_regression.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,21 @@ use crate::linalg::Matrix;
7171
use crate::math::num::RealNumber;
7272

7373
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
74-
#[derive(Debug, Clone, Eq, PartialEq)]
74+
#[derive(Debug, Default, Clone, Eq, PartialEq)]
7575
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
7676
pub enum LinearRegressionSolverName {
7777
/// QR decomposition, see [QR](../../linalg/qr/index.html)
7878
QR,
79+
#[default]
7980
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
8081
SVD,
8182
}
8283

8384
/// Linear Regression parameters
8485
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
85-
#[derive(Debug, Clone)]
86+
#[derive(Debug, Default, Clone)]
8687
pub struct LinearRegressionParameters {
88+
#[cfg_attr(feature = "serde", serde(default))]
8789
/// Solver to use for estimation of regression coefficients.
8890
pub solver: LinearRegressionSolverName,
8991
}
@@ -105,18 +107,11 @@ impl LinearRegressionParameters {
105107
}
106108
}
107109

108-
impl Default for LinearRegressionParameters {
109-
fn default() -> Self {
110-
LinearRegressionParameters {
111-
solver: LinearRegressionSolverName::SVD,
112-
}
113-
}
114-
}
115-
116110
/// Linear Regression grid search parameters
117111
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
118112
#[derive(Debug, Clone)]
119113
pub struct LinearRegressionSearchParameters {
114+
#[cfg_attr(feature = "serde", serde(default))]
120115
/// Solver to use for estimation of regression coefficients.
121116
pub solver: Vec<LinearRegressionSolverName>,
122117
}
@@ -353,5 +348,9 @@ mod tests {
353348
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
354349

355350
assert_eq!(lr, deserialized_lr);
351+
352+
let default = LinearRegressionParameters::default();
353+
let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
354+
assert_eq!(parameters.solver, default.solver);
356355
}
357356
}

0 commit comments

Comments
 (0)