@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
32
32
use crate :: linalg:: Matrix ;
33
33
use crate :: math:: num:: RealNumber ;
34
34
35
+ #[ cfg( feature = "serde" ) ]
36
+ use serde:: { Deserialize , Serialize } ;
37
+
35
38
/// Configure Behaviour of `StandardScaler`.
39
+ #[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
36
40
#[ derive( Clone , Debug , Copy , Eq , PartialEq ) ]
37
41
pub struct StandardScalerParameters {
38
42
/// Optionaly adjust mean to be zero.
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
54
58
/// deviation of one. This can improve model training for
55
59
/// scaling sensitive models like neural network or nearest
56
60
/// neighbors based models.
61
+ #[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
57
62
#[ derive( Clone , Debug , Default , Eq , PartialEq ) ]
58
63
pub struct StandardScaler < T : RealNumber > {
59
64
means : Vec < T > ,
@@ -400,5 +405,43 @@ mod tests {
400
405
Ok ( DenseMatrix :: from_2d_array( & [ & [ 0.0 , 3.0 ] , & [ 2.0 , 4.0 ] ] ) )
401
406
)
402
407
}
408
+
409
+ /// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
410
+ /// serialized and deserialized.
411
+ #[ cfg_attr( target_arch = "wasm32" , wasm_bindgen_test:: wasm_bindgen_test) ]
412
+ #[ test]
413
+ #[ cfg( feature = "serde" ) ]
414
+ fn serde_fit_for_random_values ( ) {
415
+ let fitted_scaler = StandardScaler :: fit (
416
+ & DenseMatrix :: from_2d_array ( & [
417
+ & [ 0.1004222429 , 0.2194113576 , 0.9310663354 , 0.3313593793 ] ,
418
+ & [ 0.2045493861 , 0.1683865411 , 0.5071506765 , 0.7257355264 ] ,
419
+ & [ 0.5708488802 , 0.1846414616 , 0.9590802982 , 0.5591871046 ] ,
420
+ & [ 0.8387612750 , 0.5754861361 , 0.5537109852 , 0.1077646442 ] ,
421
+ ] ) ,
422
+ StandardScalerParameters :: default ( ) ,
423
+ )
424
+ . unwrap ( ) ;
425
+
426
+ let deserialized_scaler: StandardScaler < f64 > =
427
+ serde_json:: from_str ( & serde_json:: to_string ( & fitted_scaler) . unwrap ( ) ) . unwrap ( ) ;
428
+
429
+ assert_eq ! (
430
+ deserialized_scaler. means,
431
+ vec![ 0.42864544605 , 0.2869813741 , 0.737752073825 , 0.431011663625 ] ,
432
+ ) ;
433
+
434
+ assert ! (
435
+ & DenseMatrix :: from_2d_vec( & vec![ deserialized_scaler. stds] ) . approximate_eq(
436
+ & DenseMatrix :: from_2d_array( & [ & [
437
+ 0.29426447500954 ,
438
+ 0.16758497615485 ,
439
+ 0.20820945786863 ,
440
+ 0.23329718831165
441
+ ] , ] ) ,
442
+ 0.00000000000001
443
+ )
444
+ )
445
+ }
403
446
}
404
447
}
0 commit comments