Skip to content

Commit 9db9939

Browse files
authored
Add serde to CategoricalNB (#30)
* Add serde to CategoricalNB * Implement PartialEq for CategoricalNBDistribution
1 parent ad3ac49 commit 9db9939

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

src/naive_bayes/categorical.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
66
use serde::{Deserialize, Serialize};
77

88
/// Naive Bayes classifier for categorical features
9-
#[derive(Debug)]
9+
#[derive(Serialize, Deserialize, Debug)]
1010
struct CategoricalNBDistribution<T: RealNumber> {
1111
class_labels: Vec<T>,
1212
class_priors: Vec<T>,
1313
coefficients: Vec<Vec<Vec<T>>>,
1414
}
1515

16+
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
17+
fn eq(&self, other: &Self) -> bool {
18+
if self.class_labels == other.class_labels && self.class_priors == other.class_priors {
19+
if self.coefficients.len() != other.coefficients.len() {
20+
return false;
21+
}
22+
for (a, b) in self.coefficients.iter().zip(other.coefficients.iter()) {
23+
if a.len() != b.len() {
24+
return false;
25+
}
26+
for (a_i, b_i) in a.iter().zip(b.iter()) {
27+
if a_i.len() != b_i.len() {
28+
return false;
29+
}
30+
for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
31+
if (*a_i_j - *b_i_j).abs() > T::epsilon() {
32+
return false;
33+
}
34+
}
35+
}
36+
}
37+
true
38+
} else {
39+
false
40+
}
41+
}
42+
}
43+
1644
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
1745
fn prior(&self, class_index: usize) -> T {
1846
if class_index >= self.class_labels.len() {
@@ -181,7 +209,7 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
181209
}
182210

183211
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
184-
#[derive(Debug)]
212+
#[derive(Serialize, Deserialize, Debug, PartialEq)]
185213
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
186214
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
187215
}
@@ -269,4 +297,32 @@ mod tests {
269297
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
270298
);
271299
}
300+
301+
#[test]
302+
fn serde() {
303+
let x = DenseMatrix::<f64>::from_2d_array(&[
304+
&[3., 4., 0., 1.],
305+
&[3., 0., 0., 1.],
306+
&[4., 4., 1., 2.],
307+
&[4., 2., 4., 3.],
308+
&[4., 2., 4., 2.],
309+
&[4., 1., 1., 0.],
310+
&[1., 1., 1., 1.],
311+
&[0., 4., 1., 0.],
312+
&[0., 3., 2., 1.],
313+
&[0., 3., 1., 1.],
314+
&[3., 4., 0., 1.],
315+
&[3., 4., 2., 4.],
316+
&[0., 3., 1., 2.],
317+
&[0., 4., 1., 2.],
318+
]);
319+
320+
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
321+
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
322+
323+
let deserialized_cnb: CategoricalNB<f64, DenseMatrix<f64>> =
324+
serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap();
325+
326+
assert_eq!(cnb, deserialized_cnb);
327+
}
272328
}

0 commit comments

Comments
 (0)