from builtins import map, zip
import json
import logging
+
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger(__name__)


# EVALUATION FUNCTIONS SELECTOR

+def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
+    """
+    SacreBLEU metrics
+    :param pred_list: list of hypothesis sentences
+    :param verbose: if greater than 0, the metric measures are printed out
+    :param extra_vars: extra variables, which are:
+            extra_vars['references'] - dict mapping sample indices to a list with all valid captions (id, [sentences])
+            extra_vars['tokenize_f'] - tokenization function used during model training (used again for validation)
+            extra_vars['detokenize_f'] - detokenization function used during model training (used again for validation)
+            extra_vars['tokenize_hypotheses'] - whether or not to tokenize the hypotheses during evaluation
+    :param split: split on which we are evaluating
+    :return: Dictionary with the SacreBLEU scores
+    """
+    import sacrebleu
+    gts = extra_vars[split]['references']
+    if extra_vars.get('tokenize_hypotheses', False):
+        hypo = [list(map(lambda x: extra_vars['tokenize_f'](x.strip()), line)) for line in pred_list]
+    else:
+        hypo = [line.strip() for line in pred_list]
+
+    initial_references = gts.get(0)
+    if initial_references is None:
+        raise ValueError('You need to provide at least one reference')
+
+    num_references = len(initial_references)
+    refs = [[] for _ in range(num_references)]
+    for references in gts.values():
+        assert len(references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
+        for ref_idx, reference in enumerate(references):
+            # De/tokenize references if needed
+            tokenized_ref = extra_vars['tokenize_f'](reference) if extra_vars.get('tokenize_references', False) \
+                else reference
+            detokenized_ref = extra_vars['detokenize_f'](tokenized_ref) if extra_vars.get('apply_detokenization', False) else tokenized_ref
+            refs[ref_idx].append(detokenized_ref)
+
+    scorers = [
+        (sacrebleu.corpus_bleu, "Bleu_4"),
+    ]
+
+    final_scores = {}
+    for scorer, method in scorers:
+        score = scorer(hypo, refs)
+        final_scores[method] = score.score
+
+    if verbose > 0:
+        logger.info('Computing SacreBleu scores on the %s split...' % split)
+        for metric in sorted(final_scores):
+            value = final_scores[metric]
+            logger.info(metric + ': ' + str(value))
+
+    return final_scores
+
+
def get_coco_score(pred_list, verbose, extra_vars, split):
    """
    COCO challenge metrics
@@ -91,7 +146,7 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
    import datetime
    import os
    from pycocoevalcap.vqa import vqaEval, visual_qa
-    from read_write import list2vqa
+    from keras_wrapper.extra.read_write import list2vqa

    quesFile = extra_vars[split]['quesFile']
    annFile = extra_vars[split]['annFile']
@@ -103,7 +158,8 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
    # create vqa object and vqaRes object
    vqa_ = visual_qa.VQA(annFile, quesFile)
    vqaRes = vqa_.loadRes(resFile, quesFile)
-    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2
+    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes,
+                               n=2)  # n is precision of accuracy (number of places after decimal), default is 2
    vqaEval_.evaluate()
    os.remove(resFile)  # remove temporary file

@@ -189,7 +245,8 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):

    if verbose > 0:
        logger.info(
-            '"coverage_error" (best: avg labels per sample = %f): %f' % (float(np.sum(y_gt)) / float(n_samples), coverr))
+            '"coverage_error" (best: avg labels per sample = %f): %f' % (
+                float(np.sum(y_gt)) / float(n_samples), coverr))
        logger.info('Label Ranking "average_precision" (best: 1.0): %f' % avgprec)
        logger.info('Label "ranking_loss" (best: 0.0): %f' % rankloss)
        logger.info('precision: %f' % precision)
@@ -204,9 +261,6 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
            'f1': f1}


-import numpy as np
-
-
def multiclass_metrics(pred_list, verbose, extra_vars, split):
    """
    Multiclass classification metrics. See multilabel ranking metrics in sklearn library for more info:
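
For reference, a minimal sketch of how the new get_sacrebleu_score entry point could be called. The module path (keras_wrapper.extra.evaluation), the 'val' split name, the sample data and the identity tokenize/detokenize lambdas are assumptions for illustration only; in practice the tokenization functions used during training should be passed, and the sacrebleu package must be installed.

# Usage sketch (assumed module path and placeholder tokenizers).
from keras_wrapper.extra.evaluation import get_sacrebleu_score

hypotheses = ['the cat sits on the mat', 'a dog runs in the park']
extra_vars = {
    'val': {
        # sample index -> list of valid references for that sample
        'references': {0: ['the cat sat on the mat'],
                       1: ['a dog is running in the park']}
    },
    'tokenize_f': lambda s: s,      # placeholder: reuse the training tokenizer here
    'detokenize_f': lambda s: s,    # placeholder: reuse the training detokenizer here
    'tokenize_hypotheses': False,
    'tokenize_references': False,
    'apply_detokenization': False,
}

scores = get_sacrebleu_score(hypotheses, verbose=1, extra_vars=extra_vars, split='val')
print(scores)  # {'Bleu_4': <corpus-level BLEU score>}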