Commit 6581c68

Merge pull request #223 from lvapeab/master
Add evaluation using SacreBleu
2 parents fe7f41b + 50ab7b0 commit 6581c68

File tree: 4 files changed (+81 −7 lines)

keras_wrapper/extra/evaluation.py

Lines changed: 60 additions & 6 deletions

@@ -4,6 +4,7 @@
 from builtins import map, zip
 import json
 import logging
+
 logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
 logger = logging.getLogger(__name__)

@@ -12,6 +13,60 @@

 # EVALUATION FUNCTIONS SELECTOR

+def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
+    """
+    SacreBLEU metrics
+    :param pred_list: list of hypothesis sentences
+    :param verbose: if greater than 0 the metric measures are printed out
+    :param extra_vars: extra variables, here are:
+            extra_vars['references'] - dict mapping sample indices to list with all valid captions (id, [sentences])
+            extra_vars['tokenize_f'] - tokenization function used during model training (used again for validation)
+            extra_vars['detokenize_f'] - detokenization function used during model training (used again for validation)
+            extra_vars['tokenize_hypotheses'] - whether or not to tokenize the hypotheses during evaluation
+    :param split: split on which we are evaluating
+    :return: Dictionary with the SacreBLEU scores
+    """
+    import sacrebleu
+    gts = extra_vars[split]['references']
+    if extra_vars.get('tokenize_hypotheses', False):
+        hypo = [list(map(
+            lambda x: extra_vars['tokenize_f'](x.strip()), line)) for line in pred_list]
+    else:
+        hypo = [line.strip() for line in pred_list]
+
+    initial_references = gts.get(0)
+    if initial_references is None:
+        raise ValueError('You need to provide at least one reference')
+
+    num_references = len(initial_references)
+    refs = [[] for _ in range(num_references)]
+    for references in gts.values():
+        assert len(references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
+        for ref_idx, reference in enumerate(references):
+            # De/Tokenize references if needed
+            tokenized_ref = extra_vars['tokenize_f'](reference) if extra_vars.get('tokenize_references', False) \
+                else reference
+            detokenized_ref = extra_vars['detokenize_f'](tokenized_ref) if extra_vars.get('apply_detokenization', False) else tokenized_ref
+            refs[ref_idx].append(detokenized_ref)
+
+    scorers = [
+        (sacrebleu.corpus_bleu, "Bleu_4"),
+    ]
+
+    final_scores = {}
+    for scorer, method in scorers:
+        score = scorer(hypo, refs)
+        final_scores[method] = score.score
+
+    if verbose > 0:
+        logger.info('Computing SacreBleu scores on the %s split...' % split)
+        for metric in sorted(final_scores):
+            value = final_scores[metric]
+            logger.info(metric + ': ' + str(value))
+
+    return final_scores
+
+
 def get_coco_score(pred_list, verbose, extra_vars, split):
     """
     COCO challenge metrics

@@ -91,7 +146,7 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
     import datetime
     import os
     from pycocoevalcap.vqa import vqaEval, visual_qa
-    from read_write import list2vqa
+    from keras_wrapper.extra.read_write import list2vqa

     quesFile = extra_vars[split]['quesFile']
     annFile = extra_vars[split]['annFile']

@@ -103,7 +158,8 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
     # create vqa object and vqaRes object
     vqa_ = visual_qa.VQA(annFile, quesFile)
     vqaRes = vqa_.loadRes(resFile, quesFile)
-    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2
+    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes,
+                               n=2)  # n is precision of accuracy (number of places after decimal), default is 2
     vqaEval_.evaluate()
     os.remove(resFile)  # remove temporal file

@@ -189,7 +245,8 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):

     if verbose > 0:
         logger.info(
-            '"coverage_error" (best: avg labels per sample = %f): %f' % (float(np.sum(y_gt)) / float(n_samples), coverr))
+            '"coverage_error" (best: avg labels per sample = %f): %f' % (
+                float(np.sum(y_gt)) / float(n_samples), coverr))
         logger.info('Label Ranking "average_precision" (best: 1.0): %f' % avgprec)
         logger.info('Label "ranking_loss" (best: 0.0): %f' % rankloss)
         logger.info('precision: %f' % precision)

@@ -204,9 +261,6 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
             'f1': f1}


-import numpy as np
-
-
 def multiclass_metrics(pred_list, verbose, extra_vars, split):
     """
     Multiclass classification metrics. See multilabel ranking metrics in sklearn library for more info:
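
The new get_sacrebleu_score builds the input layout that sacrebleu.corpus_bleu consumes: a flat list of (optionally detokenized) hypothesis strings plus one reference stream per reference position, where stream i collects the i-th reference of every sample; this is what the refs[ref_idx].append(...) loop assembles. A minimal usage sketch follows; the identity tokenizer/detokenizer and the example sentences are illustrative assumptions, not part of this commit:

from keras_wrapper.extra.evaluation import get_sacrebleu_score

def identity(text):
    return text  # stand-in for the tokenizer/detokenizer used at training time

extra_vars = {
    'tokenize_f': identity,
    'detokenize_f': identity,
    'apply_detokenization': False,   # use the references as-is
    'val': {'references': {0: ['the cat sat on the mat'],
                           1: ['a dog barked loudly']}},
}
hypotheses = ['the cat sat on the mat', 'a dog barked']

# Returns {'Bleu_4': <corpus-level BLEU on the 0-100 scale reported by sacrebleu>}
scores = get_sacrebleu_score(hypotheses, verbose=1, extra_vars=extra_vars, split='val')

The single score is reported under the key 'Bleu_4', presumably to stay interchangeable with the BLEU key that get_coco_score already returns.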

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ toolz
 cloudpickle
 matplotlib
 sacremoses
+sacrebleu
 scipy
 future
 cython

setup.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@
         'cloudpickle',
         'matplotlib',
         'sacremoses',
+        'sacrebleu',
         'scipy',
         'future',
         'cython',

tests/extra/test_wrapper_evaluation.py

Lines changed: 19 additions & 1 deletion

@@ -1,7 +1,25 @@
 import pytest
-from keras_wrapper.extra.evaluation import *
+import numpy as np
+from keras_wrapper.extra.evaluation import get_sacrebleu_score, get_coco_score, multilabel_metrics, compute_perplexity


+def test_get_sacrebleu_score():
+    pred_list = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
+    extra_vars = {'val': {'references': {0: ['Prediction 1 X W Z', 'Prediction 5'],
+                                         1: ['Prediction 2 X W Z', 'X Y Z'],
+                                         2: ['Prediction 3 X W Z', 'Prediction 5']}},
+
+                  'test': {'references': {0: ['Prediction 2 X W Z'],
+                                          1: ['Prediction 3 X W Z'],
+                                          2: ['Prediction 1 X W Z']}}
+                  }
+    val_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'val')
+    assert np.allclose(val_scores['Bleu_4'], 100.0, atol=1e6)
+
+
+    test_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'test')
+    assert np.allclose(test_scores['Bleu_4'], 0., atol=1e6)
+
 def test_get_coco_score():
     pred_list = ['Prediction 1', 'Prediction 2', 'Prediction 3']
     extra_vars = {'val': {'references': {0: ['Prediction 1'], 1: ['Prediction 2'],
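
For reference, the 'val' case in test_get_sacrebleu_score above is equivalent to calling sacrebleu directly on the transposed reference streams that get_sacrebleu_score builds internally. This is a sketch assuming the corpus_bleu(sys_stream, ref_streams) call used by the new function:

import sacrebleu

hypo = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
# Stream 0 holds every sample's first reference, stream 1 every sample's second reference.
refs = [['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z'],
        ['Prediction 5', 'X Y Z', 'Prediction 5']]

bleu = sacrebleu.corpus_bleu(hypo, refs)
print(bleu.score)  # the hypotheses match the first reference stream exactly, so this prints 100.0

Running only the new test is a matter of pointing pytest at tests/extra/test_wrapper_evaluation.py.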
