Skip to content

Commit 81bbdbb

Browse files
committed
fix issue #782 + add test_chrf
1 parent b4ea78c commit 81bbdbb

File tree

2 files changed

+91
-24
lines changed

2 files changed

+91
-24
lines changed

neuralmonkey/evaluators/chrf.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Dict
22
from typeguard import check_argument_types
3+
import numpy as np
34
from neuralmonkey.evaluators.evaluator import Evaluator
45

56
# pylint: disable=invalid-name
@@ -25,7 +26,6 @@ def __init__(self,
2526
super().__init__(name)
2627

2728
self.n = n
28-
self.max_ord = n
2929
self.beta_2 = beta**2
3030

3131
self.ignored = [] # type: List[str]
@@ -58,44 +58,39 @@ def score_instance(self,
5858
/ ((self.beta_2 * precision) + recall))
5959

6060
def chr_r(self,
          hyp_ngrams: "NGramDicts",
          ref_ngrams: "NGramDicts") -> float:
    """Return the character n-gram recall averaged over all orders.

    For every order ``m`` in ``1..self.n`` the recall is the clipped number
    of reference n-grams that also occur in the hypothesis, divided by the
    total number of reference n-grams of that order.

    Arguments:
        hyp_ngrams: Per-order n-gram counts of the hypothesis; index
            ``m - 1`` holds the counts for order ``m``.
        ref_ngrams: Per-order n-gram counts of the reference, same layout.

    Returns:
        The mean of the per-order recalls as a built-in ``float``. An order
        for which the reference has no n-grams at all contributes 1.0.
    """
    count_all = np.zeros(self.n)
    count_matched = np.zeros(self.n)

    for idx in range(self.n):
        for ngr, ref_count in ref_ngrams[idx].items():
            count_all[idx] += ref_count
            # Clip the match count; an n-gram missing from the hypothesis
            # contributes zero matches.
            count_matched[idx] += min(
                ref_count, hyp_ngrams[idx].get(ngr, 0))

    # Division is only carried out where the denominator is non-zero; the
    # remaining positions keep the 1.0 preset in ``out``.
    per_order = np.divide(
        count_matched, count_all,
        out=np.ones_like(count_all),
        where=count_all != 0)

    # Convert the NumPy scalar so the declared return type holds exactly.
    return float(np.mean(per_order))
7473

7574
def chr_p(self,
          hyp_ngrams: "NGramDicts",
          ref_ngrams: "NGramDicts") -> float:
    """Return the character n-gram precision averaged over all orders.

    For every order ``m`` in ``1..self.n`` the precision is the clipped
    number of hypothesis n-grams that also occur in the reference, divided
    by the total number of hypothesis n-grams of that order.

    Arguments:
        hyp_ngrams: Per-order n-gram counts of the hypothesis; index
            ``m - 1`` holds the counts for order ``m``.
        ref_ngrams: Per-order n-gram counts of the reference, same layout.

    Returns:
        The mean of the per-order precisions as a built-in ``float``. An
        order for which the hypothesis has no n-grams at all contributes
        1.0.
    """
    count_all = np.zeros(self.n)
    count_matched = np.zeros(self.n)

    for idx in range(self.n):
        for ngr, hyp_count in hyp_ngrams[idx].items():
            count_all[idx] += hyp_count
            # Clip the match count; an n-gram missing from the reference
            # contributes zero matches.
            count_matched[idx] += min(
                hyp_count, ref_ngrams[idx].get(ngr, 0))

    # Division is only carried out where the denominator is non-zero; the
    # remaining positions keep the 1.0 preset in ``out``.
    per_order = np.divide(
        count_matched, count_all,
        out=np.ones_like(count_all),
        where=count_all != 0)

    # Convert the NumPy scalar so the declared return type holds exactly.
    return float(np.mean(per_order))
9087

9188
def _get_ngrams(self, tokens: List[str], n: int) -> NGramDicts:
92-
if len(tokens) < n:
93-
self.max_ord = len(tokens)
94-
9589
ngr_dicts = []
9690
for m in range(1, n + 1):
9791
ngr_dict = {} # type: Dict[str, int]
98-
for i in range(m, len(tokens)):
92+
# if m > len(tokens), the range below is empty, so ngr_dict stays empty
93+
for i in range(m, len(tokens) + 1):
9994
ngr = "".join(tokens[i - m:i])
10095
ngr_dict[ngr] = ngr_dict.setdefault(ngr, 0) + 1
10196
ngr_dicts.append(ngr_dict)

neuralmonkey/tests/test_chrf.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3.5
"""Unit tests for the chrF evaluator."""

import unittest

from neuralmonkey.evaluators.chrf import ChrFEvaluator


CORPUS_DECODED = [
    "colorful thoughts furiously sleep",
    "little piglet slept all night",
    "working working working working working be be be be be be be",
    "ich bin walrus",
    "walrus for präsident"
]

CORPUS_REFERENCE = [
    "the colorless ideas slept furiously",
    "pooh slept all night",
    "working class hero is something to be",
    "I am the working class walrus",
    "walrus for president"
]

# Input for the n-gram extraction test and the expected counts for the
# orders 1 to 4; the last dict is empty because there are only 3 tokens.
TOKENS = ["a", "b", "a"]
NGRAMS = [
    {"a": 2, "b": 1},
    {"ab": 1, "ba": 1},
    {"aba": 1},
    {}]

DECODED = [d.split() for d in CORPUS_DECODED]
REFERENCE = [r.split() for r in CORPUS_REFERENCE]

FUNC = ChrFEvaluator()
FUNC_P = FUNC.chr_p
FUNC_R = FUNC.chr_r
FUNC_NGRAMS = FUNC._get_ngrams  # pylint: disable=protected-access

# The expected corpus scores below are quoted to two decimal places, so the
# comparison tolerance is one unit of the last quoted digit. The former
# ``delta=10`` could never fail for a score bounded by 0 and 1.
# NOTE(review): the expected values are the original author's; confirm
# that they hold at this tolerance.
SCORE_DELTA = 0.01


class TestChrF(unittest.TestCase):
    """Corpus-level and n-gram extraction tests of ``ChrFEvaluator``."""

    def test_empty_decoded(self):
        """Empty hypotheses must score zero."""
        # Recall == 0.0
        self.assertEqual(FUNC([[] for _ in DECODED], REFERENCE), 0.0)

    def test_empty_reference(self):
        """Empty references must score zero."""
        # Precision == 0.0
        # The empty sentences are passed as the references (second
        # argument); previously they were passed as the hypotheses, which
        # merely repeated test_empty_decoded.
        self.assertEqual(FUNC(DECODED, [[] for _ in DECODED]), 0.0)

    def test_identical(self):
        """A corpus compared with itself must get the perfect score."""
        self.assertEqual(FUNC(REFERENCE, REFERENCE), 1.0)

    def test_empty_sentence(self):
        """A single empty reference sentence must not break the scoring."""
        ref_empty = REFERENCE + [[]]
        out_empty = DECODED + [["something"]]
        score = FUNC(out_empty, ref_empty)
        self.assertAlmostEqual(score, 0.38, delta=SCORE_DELTA)

    def test_chrf(self):
        """Check the corpus score of the sample data."""
        score = FUNC(DECODED, REFERENCE)
        self.assertAlmostEqual(score, 0.46, delta=SCORE_DELTA)

    def test_get_ngrams(self):
        """Orders longer than the input must yield empty count dicts."""
        ngrams_out = FUNC_NGRAMS(TOKENS, 4)
        self.assertEqual(len(ngrams_out), len(NGRAMS))
        for extracted, expected in zip(ngrams_out, NGRAMS):
            self.assertDictEqual(extracted, expected)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)