Skip to content

Commit 9b983cd

Browse files
hansAngledLuffa
authored andcommitted
Sentiment evaluator: pay attention to trees with unknown words
1 parent 75ccd16 commit 9b983cd

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

src/edu/stanford/nlp/sentiment/Evaluate.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package edu.stanford.nlp.sentiment;
22
import edu.stanford.nlp.util.logging.Redwood;
33

4+
import java.text.DecimalFormat;
45
import java.util.List;
56

7+
import edu.stanford.nlp.ling.Label;
8+
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
69
import edu.stanford.nlp.trees.Tree;
710
import edu.stanford.nlp.util.Generics;
811

12+
913
/** @author John Bauer */
1014
public class Evaluate extends AbstractEvaluate {
1115

@@ -15,12 +19,72 @@ public class Evaluate extends AbstractEvaluate {
1519
final SentimentCostAndGradient cag;
1620
final SentimentModel model;
1721

22+
// Count how many trees are unknown to the model
23+
// The alternate version, ExternalEvaluate, has no concept of
24+
// unknown, so this is exclusive to the evaluate which uses a model
25+
int treesWithUnks;
26+
int treesWithUnksCorrect;
27+
1828
public Evaluate(SentimentModel model) {
1929
super(model.op);
2030
this.model = model;
2131
this.cag = new SentimentCostAndGradient(model, null);
2232
}
2333

34+
@Override
35+
public void reset() {
36+
super.reset();
37+
38+
treesWithUnks = 0;
39+
treesWithUnksCorrect = 0;
40+
}
41+
42+
@Override
43+
public void eval(Tree tree) {
44+
super.eval(tree);
45+
46+
countUnks(tree);
47+
}
48+
49+
/**
50+
* Keep track of how many trees have at least one unknown, and how
51+
* many of those have the top level annotation correct.
52+
*/
53+
protected void countUnks(Tree tree) {
54+
List<Label> labels = tree.yield();
55+
boolean hasUnk = false;
56+
for (Label label : labels) {
57+
if (!model.wordVectors.containsKey(label.value())) {
58+
hasUnk = true;
59+
break;
60+
}
61+
}
62+
63+
if (hasUnk) {
64+
int gold = RNNCoreAnnotations.getGoldClass(tree);
65+
int guess = RNNCoreAnnotations.getPredictedClass(tree);
66+
67+
treesWithUnks += 1;
68+
if (gold == guess)
69+
treesWithUnksCorrect += 1;
70+
}
71+
}
72+
73+
private static final String FORMAT = "#.##";
74+
protected DecimalFormat format = new DecimalFormat(FORMAT);
75+
76+
@Override
77+
public void printSummary() {
78+
super.printSummary();
79+
80+
log.info("Saw " + treesWithUnks + " trees with at least one unknown token.");
81+
if (treesWithUnks > 0) {
82+
double percent = (float) treesWithUnksCorrect / treesWithUnks * 100.0;
83+
log.info(treesWithUnksCorrect + " / " + treesWithUnks + " trees (" + format.format(percent) +
84+
"%) with at least one unknown token were classified correctly at the top level.");
85+
}
86+
}
87+
2488
@Override
2589
public void populatePredictedLabels(List<Tree> trees) {
2690
for (Tree tree : trees) {

0 commit comments

Comments
 (0)