Skip to content

Commit 05dfffa

Browse files
montanalowmorenol
authored andcommitted
add seed param to search params (#168)
1 parent a37b552 commit 05dfffa

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

src/cluster/kmeans.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,17 @@ pub struct KMeansSearchParameters {
145145
pub k: Vec<usize>,
146146
/// Maximum number of iterations of the k-means algorithm for a single run.
147147
pub max_iter: Vec<usize>,
148+
/// Determines random number generation for centroid initialization.
149+
/// Use an int to make the randomness deterministic
150+
pub seed: Vec<Option<u64>>,
148151
}
149152

150153
/// KMeans grid search iterator
151154
pub struct KMeansSearchParametersIterator {
152155
kmeans_search_parameters: KMeansSearchParameters,
153156
current_k: usize,
154157
current_max_iter: usize,
158+
current_seed: usize,
155159
}
156160

157161
impl IntoIterator for KMeansSearchParameters {
@@ -163,6 +167,7 @@ impl IntoIterator for KMeansSearchParameters {
163167
kmeans_search_parameters: self,
164168
current_k: 0,
165169
current_max_iter: 0,
170+
current_seed: 0,
166171
}
167172
}
168173
}
@@ -173,23 +178,30 @@ impl Iterator for KMeansSearchParametersIterator {
173178
fn next(&mut self) -> Option<Self::Item> {
174179
if self.current_k == self.kmeans_search_parameters.k.len()
175180
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
181+
&& self.current_seed == self.kmeans_search_parameters.seed.len()
176182
{
177183
return None;
178184
}
179185

180186
let next = KMeansParameters {
181187
k: self.kmeans_search_parameters.k[self.current_k],
182188
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
189+
seed: self.kmeans_search_parameters.seed[self.current_seed],
183190
};
184191

185192
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
186193
self.current_k += 1;
187194
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
188195
self.current_k = 0;
189196
self.current_max_iter += 1;
197+
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
198+
self.current_k = 0;
199+
self.current_max_iter = 0;
200+
self.current_seed += 1;
190201
} else {
191202
self.current_k += 1;
192203
self.current_max_iter += 1;
204+
self.current_seed += 1;
193205
}
194206

195207
Some(next)
@@ -203,6 +215,7 @@ impl Default for KMeansSearchParameters {
203215
KMeansSearchParameters {
204216
k: vec![default_params.k],
205217
max_iter: vec![default_params.max_iter],
218+
seed: vec![default_params.seed],
206219
}
207220
}
208221
}

src/svm/svc.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowV
119119
pub kernel: Vec<K>,
120120
/// Unused parameter.
121121
m: PhantomData<M>,
122+
/// Controls the pseudo random number generation for shuffling the data for probability estimates
123+
seed: Vec<Option<u64>>,
122124
}
123125

124126
/// SVC grid search iterator
@@ -128,6 +130,7 @@ pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T,
128130
current_c: usize,
129131
current_tol: usize,
130132
current_kernel: usize,
133+
current_seed: usize,
131134
}
132135

133136
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
@@ -143,6 +146,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
143146
current_c: 0,
144147
current_tol: 0,
145148
current_kernel: 0,
149+
current_seed: 0,
146150
}
147151
}
148152
}
@@ -157,6 +161,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
157161
&& self.current_c == self.svc_search_parameters.c.len()
158162
&& self.current_tol == self.svc_search_parameters.tol.len()
159163
&& self.current_kernel == self.svc_search_parameters.kernel.len()
164+
&& self.current_seed == self.svc_search_parameters.kernel.len()
160165
{
161166
return None;
162167
}
@@ -167,6 +172,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
167172
tol: self.svc_search_parameters.tol[self.current_tol],
168173
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
169174
m: PhantomData,
175+
seed: self.svc_search_parameters.seed[self.current_seed],
170176
};
171177

172178
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
@@ -183,11 +189,18 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
183189
self.current_c = 0;
184190
self.current_tol = 0;
185191
self.current_kernel += 1;
192+
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
193+
self.current_epoch = 0;
194+
self.current_c = 0;
195+
self.current_tol = 0;
196+
self.current_kernel = 0;
197+
self.current_seed += 1;
186198
} else {
187199
self.current_epoch += 1;
188200
self.current_c += 1;
189201
self.current_tol += 1;
190202
self.current_kernel += 1;
203+
self.current_seed += 1;
191204
}
192205

193206
Some(next)
@@ -204,6 +217,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKe
204217
tol: vec![default_params.tol],
205218
kernel: vec![default_params.kernel],
206219
m: PhantomData,
220+
seed: vec![default_params.seed],
207221
}
208222
}
209223
}

src/tree/decision_tree_classifier.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,21 @@ impl Default for DecisionTreeClassifierParameters {
209209
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
210210
#[derive(Debug, Clone)]
211211
pub struct DecisionTreeClassifierSearchParameters {
212+
#[cfg_attr(feature = "serde", serde(default))]
212213
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
213214
pub criterion: Vec<SplitCriterion>,
215+
#[cfg_attr(feature = "serde", serde(default))]
214216
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
215217
pub max_depth: Vec<Option<u16>>,
218+
#[cfg_attr(feature = "serde", serde(default))]
216219
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
217220
pub min_samples_leaf: Vec<usize>,
221+
#[cfg_attr(feature = "serde", serde(default))]
218222
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
219223
pub min_samples_split: Vec<usize>,
224+
#[cfg_attr(feature = "serde", serde(default))]
225+
/// Controls the randomness of the estimator
226+
pub seed: Vec<Option<u64>>,
220227
}
221228

222229
/// DecisionTreeClassifier grid search iterator
@@ -226,6 +233,7 @@ pub struct DecisionTreeClassifierSearchParametersIterator {
226233
current_max_depth: usize,
227234
current_min_samples_leaf: usize,
228235
current_min_samples_split: usize,
236+
current_seed: usize,
229237
}
230238

231239
impl IntoIterator for DecisionTreeClassifierSearchParameters {
@@ -239,6 +247,7 @@ impl IntoIterator for DecisionTreeClassifierSearchParameters {
239247
current_max_depth: 0,
240248
current_min_samples_leaf: 0,
241249
current_min_samples_split: 0,
250+
current_seed: 0,
242251
}
243252
}
244253
}
@@ -267,6 +276,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
267276
.decision_tree_classifier_search_parameters
268277
.min_samples_split
269278
.len()
279+
&& self.current_seed == self.decision_tree_classifier_search_parameters.seed.len()
270280
{
271281
return None;
272282
}
@@ -283,6 +293,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
283293
min_samples_split: self
284294
.decision_tree_classifier_search_parameters
285295
.min_samples_split[self.current_min_samples_split],
296+
seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed],
286297
};
287298

288299
if self.current_criterion + 1
@@ -319,11 +330,19 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
319330
self.current_max_depth = 0;
320331
self.current_min_samples_leaf = 0;
321332
self.current_min_samples_split += 1;
333+
} else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len()
334+
{
335+
self.current_criterion = 0;
336+
self.current_max_depth = 0;
337+
self.current_min_samples_leaf = 0;
338+
self.current_min_samples_split = 0;
339+
self.current_seed += 1;
322340
} else {
323341
self.current_criterion += 1;
324342
self.current_max_depth += 1;
325343
self.current_min_samples_leaf += 1;
326344
self.current_min_samples_split += 1;
345+
self.current_seed += 1;
327346
}
328347

329348
Some(next)
@@ -339,6 +358,7 @@ impl Default for DecisionTreeClassifierSearchParameters {
339358
max_depth: vec![default_params.max_depth],
340359
min_samples_leaf: vec![default_params.min_samples_leaf],
341360
min_samples_split: vec![default_params.min_samples_split],
361+
seed: vec![default_params.seed],
342362
}
343363
}
344364
}

src/tree/decision_tree_regressor.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ pub struct DecisionTreeRegressorSearchParameters {
148148
pub min_samples_leaf: Vec<usize>,
149149
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
150150
pub min_samples_split: Vec<usize>,
151+
/// Controls the randomness of the estimator
152+
pub seed: Vec<Option<u64>>,
151153
}
152154

153155
/// DecisionTreeRegressor grid search iterator
@@ -156,6 +158,7 @@ pub struct DecisionTreeRegressorSearchParametersIterator {
156158
current_max_depth: usize,
157159
current_min_samples_leaf: usize,
158160
current_min_samples_split: usize,
161+
current_seed: usize,
159162
}
160163

161164
impl IntoIterator for DecisionTreeRegressorSearchParameters {
@@ -168,6 +171,7 @@ impl IntoIterator for DecisionTreeRegressorSearchParameters {
168171
current_max_depth: 0,
169172
current_min_samples_leaf: 0,
170173
current_min_samples_split: 0,
174+
current_seed: 0,
171175
}
172176
}
173177
}
@@ -191,6 +195,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
191195
.decision_tree_regressor_search_parameters
192196
.min_samples_split
193197
.len()
198+
&& self.current_seed == self.decision_tree_regressor_search_parameters.seed.len()
194199
{
195200
return None;
196201
}
@@ -204,6 +209,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
204209
min_samples_split: self
205210
.decision_tree_regressor_search_parameters
206211
.min_samples_split[self.current_min_samples_split],
212+
seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed],
207213
};
208214

209215
if self.current_max_depth + 1
@@ -230,10 +236,17 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
230236
self.current_max_depth = 0;
231237
self.current_min_samples_leaf = 0;
232238
self.current_min_samples_split += 1;
239+
} else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len()
240+
{
241+
self.current_max_depth = 0;
242+
self.current_min_samples_leaf = 0;
243+
self.current_min_samples_split = 0;
244+
self.current_seed += 1;
233245
} else {
234246
self.current_max_depth += 1;
235247
self.current_min_samples_leaf += 1;
236248
self.current_min_samples_split += 1;
249+
self.current_seed += 1;
237250
}
238251

239252
Some(next)
@@ -248,6 +261,7 @@ impl Default for DecisionTreeRegressorSearchParameters {
248261
max_depth: vec![default_params.max_depth],
249262
min_samples_leaf: vec![default_params.min_samples_leaf],
250263
min_samples_split: vec![default_params.min_samples_split],
264+
seed: vec![default_params.seed],
251265
}
252266
}
253267
}

0 commit comments

Comments
 (0)