Skip to content

Commit 2c71059

Browse files
authored
Merge pull request #783 from ufal/chrf_fix
fix issue #782
2 parents b4ea78c + 98f3b3b commit 2c71059

File tree

2 files changed

+84
-35
lines changed

2 files changed

+84
-35
lines changed

neuralmonkey/evaluators/chrf.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Dict
22
from typeguard import check_argument_types
3+
import numpy as np
34
from neuralmonkey.evaluators.evaluator import Evaluator
45

56
# pylint: disable=invalid-name
@@ -25,7 +26,6 @@ def __init__(self,
2526
super().__init__(name)
2627

2728
self.n = n
28-
self.max_ord = n
2929
self.beta_2 = beta**2
3030

3131
self.ignored = [] # type: List[str]
@@ -37,11 +37,11 @@ def score_instance(self,
3737
reference: List[str]) -> float:
3838
hyp_joined = " ".join(hypothesis)
3939
hyp_chars = [x for x in list(hyp_joined) if x not in self.ignored]
40-
hyp_ngrams = self._get_ngrams(hyp_chars, self.n)
40+
hyp_ngrams = _get_ngrams(hyp_chars, self.n)
4141

4242
ref_joined = " ".join(reference)
4343
ref_chars = [x for x in list(ref_joined) if x not in self.ignored]
44-
ref_ngrams = self._get_ngrams(ref_chars, self.n)
44+
ref_ngrams = _get_ngrams(ref_chars, self.n)
4545

4646
if not hyp_chars or not ref_chars:
4747
if "".join(hyp_chars) == "".join(ref_chars):
@@ -58,48 +58,43 @@ def score_instance(self,
5858
/ ((self.beta_2 * precision) + recall))
5959

6060
def chr_r(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
    """Compute the character n-gram recall averaged over all orders.

    For each order from 1 to ``self.n`` the reference n-gram counts are
    summed, and every reference n-gram that also occurs in the hypothesis
    contributes the smaller of its two counts to the number of matches.

    Arguments:
        hyp_ngrams: Per-order n-gram counts of the hypothesis; index
            ``m - 1`` holds the dictionary for order ``m``.
        ref_ngrams: Per-order n-gram counts of the reference, laid out
            the same way.

    Returns:
        The mean over orders of matched / total reference n-grams. An
        order for which the reference has no n-grams contributes 1.
    """
    totals = np.zeros(self.n)
    matches = np.zeros(self.n)
    for order in range(self.n):
        for ngram, ref_count in ref_ngrams[order].items():
            totals[order] += ref_count
            if ngram in hyp_ngrams[order]:
                matches[order] += min(ref_count, hyp_ngrams[order][ngram])

    # Orders with a zero total are skipped by ``where`` and therefore keep
    # the 1.0 that ``out`` was pre-filled with; no division by zero occurs.
    has_ngrams = totals != 0
    ratios = np.divide(
        matches, totals, out=np.ones_like(totals), where=has_ngrams)
    return np.mean(ratios)
7473

7574
def chr_p(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
    """Compute the character n-gram precision averaged over all orders.

    For each order from 1 to ``self.n`` the hypothesis n-gram counts are
    summed, and every hypothesis n-gram that also occurs in the reference
    contributes the smaller of its two counts to the number of matches.

    Arguments:
        hyp_ngrams: Per-order n-gram counts of the hypothesis; index
            ``m - 1`` holds the dictionary for order ``m``.
        ref_ngrams: Per-order n-gram counts of the reference, laid out
            the same way.

    Returns:
        The mean over orders of matched / total hypothesis n-grams. An
        order for which the hypothesis has no n-grams contributes 1.
    """
    totals = np.zeros(self.n)
    matches = np.zeros(self.n)
    for order in range(self.n):
        for ngram, hyp_count in hyp_ngrams[order].items():
            totals[order] += hyp_count
            if ngram in ref_ngrams[order]:
                matches[order] += min(hyp_count, ref_ngrams[order][ngram])

    # Orders with a zero total are skipped by ``where`` and therefore keep
    # the 1.0 that ``out`` was pre-filled with; no division by zero occurs.
    has_ngrams = totals != 0
    ratios = np.divide(
        matches, totals, out=np.ones_like(totals), where=has_ngrams)
    return np.mean(ratios)
87+
88+
89+
def _get_ngrams(tokens: List[str], n: int) -> NGramDicts:
    """Count the n-grams of orders 1 through ``n`` in a token sequence.

    Arguments:
        tokens: The sequence of tokens to slide the n-gram window over.
        n: The highest n-gram order to collect.

    Returns:
        A list of ``n`` dictionaries. The dictionary at index ``m - 1`` maps
        every n-gram of order ``m`` (its tokens joined into one string) to
        its number of occurrences. Orders longer than the sequence yield
        empty dictionaries.
    """
    per_order = []
    for order in range(1, n + 1):
        counts = {}  # type: Dict[str, int]
        # The last window starts at ``len(tokens) - order`` so that the
        # final token is included in an n-gram of every feasible order.
        for start in range(len(tokens) - order + 1):
            ngram = "".join(tokens[start:start + order])
            counts[ngram] = counts.get(ngram, 0) + 1
        per_order.append(counts)
    return per_order
10398

10499

105100
# pylint: disable=invalid-name

neuralmonkey/tests/test_chrf.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3.5
2+
3+
4+
import unittest
5+
6+
from neuralmonkey.evaluators.chrf import ChrFEvaluator, _get_ngrams
7+
from neuralmonkey.tests.test_bleu import DECODED, REFERENCE
8+
9+
10+
TOKENS = ["a", "b", "a"]
11+
NGRAMS = [
12+
{"a": 2, "b": 1},
13+
{"ab": 1, "ba": 1},
14+
{"aba": 1},
15+
{}]
16+
17+
FUNC = ChrFEvaluator()
18+
FUNC_P = FUNC.chr_p
19+
FUNC_R = FUNC.chr_r
20+
21+
22+
class TestChrF(unittest.TestCase):
    """Tests for the chrF evaluator and its n-gram counting helper.

    The evaluator is called as ``FUNC(hypotheses, references)``.
    """

    def test_empty_decoded(self):
        # Recall == 0.0
        empty_hypotheses = [[] for _ in DECODED]
        self.assertEqual(FUNC(empty_hypotheses, REFERENCE), 0.0)

    def test_empty_reference(self):
        # Precision == 0.0
        # The empty sentences must be passed as the references (second
        # argument); passing them first would just repeat the
        # empty-hypothesis scenario of ``test_empty_decoded``.
        empty_references = [[] for _ in DECODED]
        self.assertEqual(FUNC(DECODED, empty_references), 0.0)

    def test_identical(self):
        self.assertEqual(FUNC(REFERENCE, REFERENCE), 1.0)

    def test_empty_sentence(self):
        # One extra pair with an empty reference and a non-empty hypothesis.
        ref_empty = REFERENCE + [[]]
        out_empty = DECODED + [["something"]]
        score = FUNC(out_empty, ref_empty)
        # NOTE(review): delta=10 accepts any score within 10 of 0.38, so
        # this effectively only checks that scoring runs without error.
        # TODO: tighten once the exact expected score is confirmed.
        self.assertAlmostEqual(score, 0.38, delta=10)

    def test_chrf(self):
        score = FUNC(DECODED, REFERENCE)
        # NOTE(review): same overly loose tolerance as above -- TODO tighten.
        self.assertAlmostEqual(score, 0.46, delta=10)

    def test_get_ngrams(self):
        # Order 4 exceeds the three tokens and must yield an empty dict.
        ngrams_out = _get_ngrams(TOKENS, 4)
        self.assertEqual(len(ngrams_out), 4)
        for actual, expected in zip(ngrams_out, NGRAMS):
            self.assertDictEqual(actual, expected)
52+
53+
if __name__ == "__main__":
54+
unittest.main()

0 commit comments

Comments
 (0)