1
1
package edu .stanford .nlp .sentiment ;
2
2
import edu .stanford .nlp .util .logging .Redwood ;
3
3
4
+ import java .text .DecimalFormat ;
4
5
import java .util .List ;
5
6
7
+ import edu .stanford .nlp .ling .Label ;
8
+ import edu .stanford .nlp .neural .rnn .RNNCoreAnnotations ;
6
9
import edu .stanford .nlp .trees .Tree ;
7
10
import edu .stanford .nlp .util .Generics ;
8
11
12
+
9
13
/** @author John Bauer */
10
14
public class Evaluate extends AbstractEvaluate {
11
15
@@ -15,12 +19,72 @@ public class Evaluate extends AbstractEvaluate {
15
19
final SentimentCostAndGradient cag ;
16
20
final SentimentModel model ;
17
21
22
+ // Count how many trees are unknown to the model
23
+ // The alternate version, ExternalEvaluate, has no concept of
24
+ // unknown, so this is exclusive to the evaluate which uses a model
25
+ int treesWithUnks ;
26
+ int treesWithUnksCorrect ;
27
+
18
28
/**
 * Creates an evaluator backed by the given sentiment model.
 *
 * @param model the model to evaluate with; its {@code op} field is
 *              passed to the superclass constructor
 */
public Evaluate(SentimentModel model) {
  super(model.op);
  // The cost-and-gradient helper is built from the same model, with null as its second argument.
  this.cag = new SentimentCostAndGradient(model, null);
  this.model = model;
}
23
33
34
+ @ Override
35
+ public void reset () {
36
+ super .reset ();
37
+
38
+ treesWithUnks = 0 ;
39
+ treesWithUnksCorrect = 0 ;
40
+ }
41
+
42
+ @ Override
43
+ public void eval (Tree tree ) {
44
+ super .eval (tree );
45
+
46
+ countUnks (tree );
47
+ }
48
+
49
+ /**
50
+ * Keep track of how many trees have at least one unknown, and how
51
+ * many of those have the top level annotation correct.
52
+ */
53
+ protected void countUnks (Tree tree ) {
54
+ List <Label > labels = tree .yield ();
55
+ boolean hasUnk = false ;
56
+ for (Label label : labels ) {
57
+ if (!model .wordVectors .containsKey (label .value ())) {
58
+ hasUnk = true ;
59
+ break ;
60
+ }
61
+ }
62
+
63
+ if (hasUnk ) {
64
+ int gold = RNNCoreAnnotations .getGoldClass (tree );
65
+ int guess = RNNCoreAnnotations .getPredictedClass (tree );
66
+
67
+ treesWithUnks += 1 ;
68
+ if (gold == guess )
69
+ treesWithUnksCorrect += 1 ;
70
+ }
71
+ }
72
+
73
+ private static final String FORMAT = "#.##" ;
74
+ protected DecimalFormat format = new DecimalFormat (FORMAT );
75
+
76
+ @ Override
77
+ public void printSummary () {
78
+ super .printSummary ();
79
+
80
+ log .info ("Saw " + treesWithUnks + " trees with at least one unknown token." );
81
+ if (treesWithUnks > 0 ) {
82
+ double percent = (float ) treesWithUnksCorrect / treesWithUnks * 100.0 ;
83
+ log .info (treesWithUnksCorrect + " / " + treesWithUnks + " trees (" + format .format (percent ) +
84
+ "%) with at least one unknown token were classified correctly at the top level." );
85
+ }
86
+ }
87
+
24
88
@ Override
25
89
public void populatePredictedLabels (List <Tree > trees ) {
26
90
for (Tree tree : trees ) {
0 commit comments