Skip to content

Commit 0f442e9

Browse files
montanalowmorenol
authored andcommitted
Handle multiclass precision/recall (#152)
* handle multiclass precision/recall
1 parent 44e4be2 commit 0f442e9

File tree

3 files changed

+97
-46
lines changed

3 files changed

+97
-46
lines changed

src/math/num.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ pub trait RealNumber:
4646
self * self
4747
}
4848

49-
/// Raw transmutation to u64
49+
/// Raw transmutation to u32
5050
fn to_f32_bits(self) -> u32;
51+
52+
/// Raw transmutation to u64
53+
fn to_f64_bits(self) -> u64;
5154
}
5255

5356
impl RealNumber for f64 {
@@ -89,6 +92,10 @@ impl RealNumber for f64 {
8992
fn to_f32_bits(self) -> u32 {
9093
self.to_bits() as u32
9194
}
95+
96+
fn to_f64_bits(self) -> u64 {
97+
self.to_bits()
98+
}
9299
}
93100

94101
impl RealNumber for f32 {
@@ -130,6 +137,10 @@ impl RealNumber for f32 {
130137
fn to_f32_bits(self) -> u32 {
131138
self.to_bits()
132139
}
140+
141+
fn to_f64_bits(self) -> u64 {
142+
self.to_bits() as u64
143+
}
133144
}
134145

135146
#[cfg(test)]

src/metrics/precision.rs

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
//!
1919
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
2020
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
21+
use std::collections::HashSet;
22+
2123
#[cfg(feature = "serde")]
2224
use serde::{Deserialize, Serialize};
2325

@@ -42,34 +44,33 @@ impl Precision {
4244
);
4345
}
4446

45-
let mut tp = 0;
46-
let mut p = 0;
47-
let n = y_true.len();
48-
for i in 0..n {
49-
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
50-
panic!(
51-
"Precision can only be applied to binary classification: {}",
52-
y_true.get(i)
53-
);
54-
}
55-
56-
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
57-
panic!(
58-
"Precision can only be applied to binary classification: {}",
59-
y_pred.get(i)
60-
);
61-
}
62-
63-
if y_pred.get(i) == T::one() {
64-
p += 1;
47+
let mut classes = HashSet::new();
48+
for i in 0..y_true.len() {
49+
classes.insert(y_true.get(i).to_f64_bits());
50+
}
51+
let classes = classes.len();
6552

66-
if y_true.get(i) == T::one() {
53+
let mut tp = 0;
54+
let mut fp = 0;
55+
for i in 0..y_true.len() {
56+
if y_pred.get(i) == y_true.get(i) {
57+
if classes == 2 {
58+
if y_true.get(i) == T::one() {
59+
tp += 1;
60+
}
61+
} else {
6762
tp += 1;
6863
}
64+
} else if classes == 2 {
65+
if y_true.get(i) == T::one() {
66+
fp += 1;
67+
}
68+
} else {
69+
fp += 1;
6970
}
7071
}
7172

72-
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
73+
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
7374
}
7475
}
7576

@@ -88,5 +89,24 @@ mod tests {
8889

8990
assert!((score1 - 0.5).abs() < 1e-8);
9091
assert!((score2 - 1.0).abs() < 1e-8);
92+
93+
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
94+
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
95+
96+
let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
97+
assert!((score3 - 0.5).abs() < 1e-8);
98+
}
99+
100+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
101+
#[test]
102+
fn precision_multiclass() {
103+
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
104+
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
105+
106+
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
107+
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
108+
109+
assert!((score1 - 0.333333333).abs() < 1e-8);
110+
assert!((score2 - 1.0).abs() < 1e-8);
91111
}
92112
}

src/metrics/recall.rs

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
//!
1919
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
2020
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
21+
use std::collections::HashSet;
22+
use std::convert::TryInto;
23+
2124
#[cfg(feature = "serde")]
2225
use serde::{Deserialize, Serialize};
2326

@@ -42,34 +45,32 @@ impl Recall {
4245
);
4346
}
4447

45-
let mut tp = 0;
46-
let mut p = 0;
47-
let n = y_true.len();
48-
for i in 0..n {
49-
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
50-
panic!(
51-
"Recall can only be applied to binary classification: {}",
52-
y_true.get(i)
53-
);
54-
}
55-
56-
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
57-
panic!(
58-
"Recall can only be applied to binary classification: {}",
59-
y_pred.get(i)
60-
);
61-
}
62-
63-
if y_true.get(i) == T::one() {
64-
p += 1;
48+
let mut classes = HashSet::new();
49+
for i in 0..y_true.len() {
50+
classes.insert(y_true.get(i).to_f64_bits());
51+
}
52+
let classes: i64 = classes.len().try_into().unwrap();
6553

66-
if y_pred.get(i) == T::one() {
54+
let mut tp = 0;
55+
let mut fne = 0;
56+
for i in 0..y_true.len() {
57+
if y_pred.get(i) == y_true.get(i) {
58+
if classes == 2 {
59+
if y_true.get(i) == T::one() {
60+
tp += 1;
61+
}
62+
} else {
6763
tp += 1;
6864
}
65+
} else if classes == 2 {
66+
if y_true.get(i) != T::one() {
67+
fne += 1;
68+
}
69+
} else {
70+
fne += 1;
6971
}
7072
}
71-
72-
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
73+
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
7374
}
7475
}
7576

@@ -88,5 +89,24 @@ mod tests {
8889

8990
assert!((score1 - 0.5).abs() < 1e-8);
9091
assert!((score2 - 1.0).abs() < 1e-8);
92+
93+
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
94+
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
95+
96+
let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
97+
assert!((score3 - 0.66666666).abs() < 1e-8);
98+
}
99+
100+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
101+
#[test]
102+
fn recall_multiclass() {
103+
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
104+
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
105+
106+
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
107+
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
108+
109+
assert!((score1 - 0.333333333).abs() < 1e-8);
110+
assert!((score2 - 1.0).abs() < 1e-8);
91111
}
92112
}

0 commit comments

Comments
 (0)