Skip to content

Commit a37b552

Browse files
committed
Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests * feat: add seed parameter to multiple algorithms * Update changelog Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
1 parent 55e1158 commit a37b552

File tree

14 files changed

+139
-64
lines changed

14 files changed

+139
-64
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@ name: CI
22

33
on:
44
push:
5-
branches: [ main, development ]
5+
branches: [main, development]
66
pull_request:
7-
branches: [ development ]
7+
branches: [development]
88

99
jobs:
1010
tests:
1111
runs-on: "${{ matrix.platform.os }}-latest"
1212
strategy:
1313
matrix:
14-
platform: [
15-
{ os: "windows", target: "x86_64-pc-windows-msvc" },
16-
{ os: "windows", target: "i686-pc-windows-msvc" },
17-
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
18-
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
19-
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
20-
{ os: "macos", target: "aarch64-apple-darwin" },
21-
]
14+
platform:
15+
[
16+
{ os: "windows", target: "x86_64-pc-windows-msvc" },
17+
{ os: "windows", target: "i686-pc-windows-msvc" },
18+
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
19+
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
20+
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
21+
{ os: "macos", target: "aarch64-apple-darwin" },
22+
]
2223
env:
2324
TZ: "/usr/share/zoneinfo/your/location"
2425
steps:
@@ -40,7 +41,7 @@ jobs:
4041
default: true
4142
- name: Install test runner for wasm
4243
if: matrix.platform.target == 'wasm32-unknown-unknown'
43-
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
44+
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
4445
- name: Stable Build
4546
uses: actions-rs/cargo@v1
4647
with:

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
## Added
10+
- Seeds to multiple algorithims that depend on random number generation.
11+
- Added feature `js` to use WASM in browser
12+
13+
## BREAKING CHANGE
14+
- Added a new parameter to `train_test_split` to define the seed.
15+
16+
## [0.2.1] - 2022-05-10
17+
918
## Added
1019
- L2 regularization penalty to the Logistic Regression
1120
- Getters for the naive bayes structs

Cargo.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,25 @@ categories = ["science"]
1616
default = ["datasets"]
1717
ndarray-bindings = ["ndarray"]
1818
nalgebra-bindings = ["nalgebra"]
19-
datasets = ["rand_distr"]
19+
datasets = ["rand_distr", "std"]
2020
fp_bench = ["itertools"]
21+
std = ["rand/std", "rand/std_rng"]
22+
# wasm32 only
23+
js = ["getrandom/js"]
2124

2225
[dependencies]
2326
ndarray = { version = "0.15", optional = true }
2427
nalgebra = { version = "0.31", optional = true }
2528
num-traits = "0.2"
2629
num = "0.4"
27-
rand = "0.8"
30+
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
2831
rand_distr = { version = "0.4", optional = true }
2932
serde = { version = "1", features = ["derive"], optional = true }
3033
itertools = { version = "0.10.3", optional = true }
34+
cfg-if = "1.0.0"
3135

3236
[target.'cfg(target_arch = "wasm32")'.dependencies]
33-
getrandom = { version = "0.2", features = ["js"] }
37+
getrandom = { version = "0.2", optional = true }
3438

3539
[dev-dependencies]
3640
smartcore = { path = ".", features = ["fp_bench"] }

src/cluster/kmeans.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@
5252
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
5353
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
5454
55-
use rand::Rng;
5655
use std::fmt::Debug;
5756
use std::iter::Sum;
5857

58+
use ::rand::Rng;
5959
#[cfg(feature = "serde")]
6060
use serde::{Deserialize, Serialize};
6161

@@ -65,6 +65,7 @@ use crate::error::Failed;
6565
use crate::linalg::Matrix;
6666
use crate::math::distance::euclidian::*;
6767
use crate::math::num::RealNumber;
68+
use crate::rand::get_rng_impl;
6869

6970
/// K-Means clustering algorithm
7071
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -108,6 +109,9 @@ pub struct KMeansParameters {
108109
pub k: usize,
109110
/// Maximum number of iterations of the k-means algorithm for a single run.
110111
pub max_iter: usize,
112+
/// Determines random number generation for centroid initialization.
113+
/// Use an int to make the randomness deterministic
114+
pub seed: Option<u64>,
111115
}
112116

113117
impl KMeansParameters {
@@ -128,6 +132,7 @@ impl Default for KMeansParameters {
128132
KMeansParameters {
129133
k: 2,
130134
max_iter: 100,
135+
seed: None,
131136
}
132137
}
133138
}
@@ -238,7 +243,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
238243
let (n, d) = data.shape();
239244

240245
let mut distortion = T::max_value();
241-
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
246+
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
242247
let mut size = vec![0; parameters.k];
243248
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
244249

@@ -311,8 +316,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
311316
Ok(result.to_row_vector())
312317
}
313318

314-
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
315-
let mut rng = rand::thread_rng();
319+
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
320+
let mut rng = get_rng_impl(seed);
316321
let (n, m) = data.shape();
317322
let mut y = vec![0; n];
318323
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));

src/ensemble/random_forest_classifier.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
//!
4646
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
4747
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
48-
use rand::rngs::StdRng;
49-
use rand::{Rng, SeedableRng};
48+
use rand::Rng;
49+
5050
use std::default::Default;
5151
use std::fmt::Debug;
5252

@@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator};
5757
use crate::error::{Failed, FailedError};
5858
use crate::linalg::Matrix;
5959
use crate::math::num::RealNumber;
60+
use crate::rand::get_rng_impl;
6061
use crate::tree::decision_tree_classifier::{
6162
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
6263
};
@@ -441,7 +442,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
441442
.unwrap()
442443
});
443444

444-
let mut rng = StdRng::seed_from_u64(parameters.seed);
445+
let mut rng = get_rng_impl(Some(parameters.seed));
445446
let classes = y_m.unique();
446447
let k = classes.len();
447448
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
@@ -462,9 +463,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
462463
max_depth: parameters.max_depth,
463464
min_samples_leaf: parameters.min_samples_leaf,
464465
min_samples_split: parameters.min_samples_split,
466+
seed: Some(parameters.seed),
465467
};
466-
let tree =
467-
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
468+
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
468469
trees.push(tree);
469470
}
470471

src/ensemble/random_forest_regressor.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
4444
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
4545
46-
use rand::rngs::StdRng;
47-
use rand::{Rng, SeedableRng};
46+
use rand::Rng;
47+
4848
use std::default::Default;
4949
use std::fmt::Debug;
5050

@@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
5555
use crate::error::{Failed, FailedError};
5656
use crate::linalg::Matrix;
5757
use crate::math::num::RealNumber;
58+
use crate::rand::get_rng_impl;
5859
use crate::tree::decision_tree_regressor::{
5960
DecisionTreeRegressor, DecisionTreeRegressorParameters,
6061
};
@@ -376,7 +377,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
376377
.m
377378
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
378379

379-
let mut rng = StdRng::seed_from_u64(parameters.seed);
380+
let mut rng = get_rng_impl(Some(parameters.seed));
380381
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
381382

382383
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
@@ -393,9 +394,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
393394
max_depth: parameters.max_depth,
394395
min_samples_leaf: parameters.min_samples_leaf,
395396
min_samples_split: parameters.min_samples_split,
397+
seed: Some(parameters.seed),
396398
};
397-
let tree =
398-
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
399+
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
399400
trees.push(tree);
400401
}
401402

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,5 @@ pub mod readers;
101101
pub mod svm;
102102
/// Supervised tree-based learning methods
103103
pub mod tree;
104+
105+
pub(crate) mod rand;

src/math/num.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use std::iter::{Product, Sum};
99
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
1010
use std::str::FromStr;
1111

12+
use crate::rand::get_rng_impl;
13+
1214
/// Defines real number
1315
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
1416
pub trait RealNumber:
@@ -79,7 +81,7 @@ impl RealNumber for f64 {
7981
}
8082

8183
fn rand() -> f64 {
82-
let mut rng = rand::thread_rng();
84+
let mut rng = get_rng_impl(None);
8385
rng.gen()
8486
}
8587

@@ -124,7 +126,7 @@ impl RealNumber for f32 {
124126
}
125127

126128
fn rand() -> f32 {
127-
let mut rng = rand::thread_rng();
129+
let mut rng = get_rng_impl(None);
128130
rng.gen()
129131
}
130132

src/model_selection/kfold.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
use crate::linalg::Matrix;
66
use crate::math::num::RealNumber;
77
use crate::model_selection::BaseKFold;
8+
use crate::rand::get_rng_impl;
89
use rand::seq::SliceRandom;
9-
use rand::thread_rng;
1010

1111
/// K-Folds cross-validator
1212
pub struct KFold {
1313
/// Number of folds. Must be at least 2.
1414
pub n_splits: usize, // cannot exceed std::usize::MAX
1515
/// Whether to shuffle the data before splitting into batches
1616
pub shuffle: bool,
17+
/// When shuffle is True, seed affects the ordering of the indices.
18+
/// Which controls the randomness of each fold
19+
pub seed: Option<u64>,
1720
}
1821

1922
impl KFold {
@@ -23,8 +26,10 @@ impl KFold {
2326

2427
// initialise indices
2528
let mut indices: Vec<usize> = (0..n_samples).collect();
29+
let mut rng = get_rng_impl(self.seed);
30+
2631
if self.shuffle {
27-
indices.shuffle(&mut thread_rng());
32+
indices.shuffle(&mut rng);
2833
}
2934
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
3035
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
@@ -66,6 +71,7 @@ impl Default for KFold {
6671
KFold {
6772
n_splits: 3,
6873
shuffle: true,
74+
seed: None,
6975
}
7076
}
7177
}
@@ -81,6 +87,12 @@ impl KFold {
8187
self.shuffle = shuffle;
8288
self
8389
}
90+
91+
/// When shuffle is True, random_state affects the ordering of the indices.
92+
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
93+
self.seed = seed;
94+
self
95+
}
8496
}
8597

8698
/// An iterator over indices that split data into training and test set.
@@ -150,6 +162,7 @@ mod tests {
150162
let k = KFold {
151163
n_splits: 3,
152164
shuffle: false,
165+
seed: None,
153166
};
154167
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
155168
let test_indices = k.test_indices(&x);
@@ -165,6 +178,7 @@ mod tests {
165178
let k = KFold {
166179
n_splits: 3,
167180
shuffle: false,
181+
seed: None,
168182
};
169183
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
170184
let test_indices = k.test_indices(&x);
@@ -180,6 +194,7 @@ mod tests {
180194
let k = KFold {
181195
n_splits: 2,
182196
shuffle: false,
197+
seed: None,
183198
};
184199
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
185200
let test_masks = k.test_masks(&x);
@@ -206,6 +221,7 @@ mod tests {
206221
let k = KFold {
207222
n_splits: 2,
208223
shuffle: false,
224+
seed: None,
209225
};
210226
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
211227
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
@@ -238,6 +254,7 @@ mod tests {
238254
let k = KFold {
239255
n_splits: 3,
240256
shuffle: false,
257+
seed: None,
241258
};
242259
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
243260
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![

0 commit comments

Comments
 (0)