Skip to content

Commit 4d5f64c

Browse files
authored
Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for `StandardScaler`. * Add relevant unit test. Signed-off-by: Christos Katsakioris <ckatsak@gmail.com> Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
1 parent d305406 commit 4d5f64c

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/preprocessing/numerical.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
3232
use crate::linalg::Matrix;
3333
use crate::math::num::RealNumber;
3434

35+
#[cfg(feature = "serde")]
36+
use serde::{Deserialize, Serialize};
37+
3538
/// Configure Behaviour of `StandardScaler`.
39+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
3640
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
3741
pub struct StandardScalerParameters {
3842
/// Optionally adjust mean to be zero.
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
5458
/// deviation of one. This can improve model training for
5559
/// scaling-sensitive models like neural networks or nearest
5660
/// neighbors based models.
61+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
5762
#[derive(Clone, Debug, Default, Eq, PartialEq)]
5863
pub struct StandardScaler<T: RealNumber> {
5964
means: Vec<T>,
@@ -400,5 +405,43 @@ mod tests {
400405
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
401406
)
402407
}
408+
409+
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
410+
/// serialized and deserialized.
///
/// The fitted scaler is round-tripped through a JSON string with `serde_json`, and the
/// assertions are made on the deserialized copy. Only compiled when the `serde` feature
/// is enabled.
411+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
412+
#[test]
413+
#[cfg(feature = "serde")]
414+
fn serde_fit_for_random_values() {
// Fit a scaler on a fixed 4x4 matrix using the default parameters.
415+
let fitted_scaler = StandardScaler::fit(
416+
&DenseMatrix::from_2d_array(&[
417+
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
418+
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
419+
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
420+
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
421+
]),
422+
StandardScalerParameters::default(),
423+
)
424+
.unwrap();
425+
// Round-trip: serialize to a JSON string, then deserialize into a fresh scaler.
426+
let deserialized_scaler: StandardScaler<f64> =
427+
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
428+
// The means of the deserialized scaler are compared for exact equality.
429+
assert_eq!(
430+
deserialized_scaler.means,
431+
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
432+
);
433+
// The standard deviations are compared approximately, with a tolerance of 1e-14.
434+
assert!(
435+
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
436+
&DenseMatrix::from_2d_array(&[&[
437+
0.29426447500954,
438+
0.16758497615485,
439+
0.20820945786863,
440+
0.23329718831165
441+
],]),
442+
0.00000000000001
443+
)
444+
)
445+
}
403446
}
404447
}

0 commit comments

Comments
 (0)