@@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
6
6
use serde:: { Deserialize , Serialize } ;
7
7
8
8
/// Naive Bayes classifier for categorical features
9
- #[ derive( Debug ) ]
9
+ #[ derive( Serialize , Deserialize , Debug ) ]
10
10
struct CategoricalNBDistribution < T : RealNumber > {
11
11
class_labels : Vec < T > ,
12
12
class_priors : Vec < T > ,
13
13
coefficients : Vec < Vec < Vec < T > > > ,
14
14
}
15
15
16
+ impl < T : RealNumber > PartialEq for CategoricalNBDistribution < T > {
17
+ fn eq ( & self , other : & Self ) -> bool {
18
+ if self . class_labels == other. class_labels && self . class_priors == other. class_priors {
19
+ if self . coefficients . len ( ) != other. coefficients . len ( ) {
20
+ return false ;
21
+ }
22
+ for ( a, b) in self . coefficients . iter ( ) . zip ( other. coefficients . iter ( ) ) {
23
+ if a. len ( ) != b. len ( ) {
24
+ return false ;
25
+ }
26
+ for ( a_i, b_i) in a. iter ( ) . zip ( b. iter ( ) ) {
27
+ if a_i. len ( ) != b_i. len ( ) {
28
+ return false ;
29
+ }
30
+ for ( a_i_j, b_i_j) in a_i. iter ( ) . zip ( b_i. iter ( ) ) {
31
+ if ( * a_i_j - * b_i_j) . abs ( ) > T :: epsilon ( ) {
32
+ return false ;
33
+ }
34
+ }
35
+ }
36
+ }
37
+ true
38
+ } else {
39
+ false
40
+ }
41
+ }
42
+ }
43
+
16
44
impl < T : RealNumber , M : Matrix < T > > NBDistribution < T , M > for CategoricalNBDistribution < T > {
17
45
fn prior ( & self , class_index : usize ) -> T {
18
46
if class_index >= self . class_labels . len ( ) {
@@ -181,7 +209,7 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
181
209
}
182
210
183
211
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
184
- #[ derive( Debug ) ]
212
+ #[ derive( Serialize , Deserialize , Debug , PartialEq ) ]
185
213
pub struct CategoricalNB < T : RealNumber , M : Matrix < T > > {
186
214
inner : BaseNaiveBayes < T , M , CategoricalNBDistribution < T > > ,
187
215
}
@@ -269,4 +297,32 @@ mod tests {
269
297
vec![ 0. , 0. , 1. , 1. , 1. , 0. , 1. , 0. , 1. , 1. , 0. , 1. , 1. , 1. ]
270
298
) ;
271
299
}
300
+
301
+ #[ test]
302
+ fn serde ( ) {
303
+ let x = DenseMatrix :: < f64 > :: from_2d_array ( & [
304
+ & [ 3. , 4. , 0. , 1. ] ,
305
+ & [ 3. , 0. , 0. , 1. ] ,
306
+ & [ 4. , 4. , 1. , 2. ] ,
307
+ & [ 4. , 2. , 4. , 3. ] ,
308
+ & [ 4. , 2. , 4. , 2. ] ,
309
+ & [ 4. , 1. , 1. , 0. ] ,
310
+ & [ 1. , 1. , 1. , 1. ] ,
311
+ & [ 0. , 4. , 1. , 0. ] ,
312
+ & [ 0. , 3. , 2. , 1. ] ,
313
+ & [ 0. , 3. , 1. , 1. ] ,
314
+ & [ 3. , 4. , 0. , 1. ] ,
315
+ & [ 3. , 4. , 2. , 4. ] ,
316
+ & [ 0. , 3. , 1. , 2. ] ,
317
+ & [ 0. , 4. , 1. , 2. ] ,
318
+ ] ) ;
319
+
320
+ let y = vec ! [ 0. , 0. , 1. , 1. , 1. , 0. , 1. , 0. , 1. , 1. , 1. , 1. , 1. , 0. ] ;
321
+ let cnb = CategoricalNB :: fit ( & x, & y, Default :: default ( ) ) . unwrap ( ) ;
322
+
323
+ let deserialized_cnb: CategoricalNB < f64 , DenseMatrix < f64 > > =
324
+ serde_json:: from_str ( & serde_json:: to_string ( & cnb) . unwrap ( ) ) . unwrap ( ) ;
325
+
326
+ assert_eq ! ( cnb, deserialized_cnb) ;
327
+ }
272
328
}
0 commit comments