Skip to content

Commit cf86a8b

Browse files
committed
Add the ability to read & write a fast coref model
1 parent 50e3dc3 commit cf86a8b

File tree

3 files changed

+111
-5
lines changed

3 files changed

+111
-5
lines changed

src/edu/stanford/nlp/coref/fastneural/FastNeuralCorefModel.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.io.Serializable;
44
import java.util.ArrayList;
5+
import java.util.Collections;
56
import java.util.HashMap;
67
import java.util.List;
78
import java.util.Map;
@@ -37,8 +38,9 @@ public class FastNeuralCorefModel implements Serializable {
3738
private List<SimpleMatrix> networkLayers;
3839

3940
public FastNeuralCorefModel(EmbeddingExtractor embeddingExtractor,
40-
Map<String, Integer> pairFeatureIds, Map<String, Integer> mentionFeatureIds,
41-
List<SimpleMatrix> weights) {
41+
Map<String, Integer> pairFeatureIds,
42+
Map<String, Integer> mentionFeatureIds,
43+
List<SimpleMatrix> weights) {
4244
this.embeddingExtractor = embeddingExtractor;
4345
this.pairFeatureIds = pairFeatureIds;
4446
this.mentionFeatureIds = mentionFeatureIds;
@@ -53,6 +55,31 @@ public FastNeuralCorefModel(EmbeddingExtractor embeddingExtractor,
5355
networkLayers = new ArrayList<>(weights.subList(7, weights.size()));
5456
}
5557

58+
public EmbeddingExtractor getEmbeddingExtractor() {
59+
return embeddingExtractor;
60+
}
61+
62+
public Map<String, Integer> getPairFeatureIds() {
63+
return Collections.unmodifiableMap(pairFeatureIds);
64+
}
65+
66+
public Map<String, Integer> getMentionFeatureIds() {
67+
return Collections.unmodifiableMap(mentionFeatureIds);
68+
}
69+
70+
public List<SimpleMatrix> getAllWeights() {
71+
List<SimpleMatrix> weights = new ArrayList<>();
72+
weights.add(anaphorKernel);
73+
weights.add(anaphorBias);
74+
weights.add(antecedentKernel);
75+
weights.add(anaphorBias);
76+
weights.add(pairFeaturesKernel);
77+
weights.add(pairFeaturesBias);
78+
weights.add(NARepresentation);
79+
weights.addAll(networkLayers);
80+
return Collections.unmodifiableList(weights);
81+
}
82+
5683
public double score(Mention antecedent, Mention anaphor, Counter<String> antecedentFeatures,
5784
Counter<String> anaphorFeatures, Counter<String> pairFeatures,
5885
Map<Integer, SimpleMatrix> antecedentCache, Map<Integer, SimpleMatrix> anaphorCache) {

src/edu/stanford/nlp/coref/neural/EmbeddingExtractor.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,29 @@ public class EmbeddingExtractor implements Serializable {
2828
private final String naEmbedding;
2929

3030
public EmbeddingExtractor(boolean conll, Embedding staticWordEmbeddings,
31-
Embedding tunedWordEmbeddings, String naEmbedding) {
31+
Embedding tunedWordEmbeddings, String naEmbedding) {
3232
this.conll = conll;
3333
this.staticWordEmbeddings = staticWordEmbeddings;
3434
this.tunedWordEmbeddings = tunedWordEmbeddings;
3535
this.naEmbedding = naEmbedding;
3636
}
3737

38+
public boolean isConll() {
39+
return conll;
40+
}
41+
42+
public Embedding getStaticWordEmbeddings() {
43+
return staticWordEmbeddings;
44+
}
45+
46+
public Embedding getTunedWordEmbeddings() {
47+
return tunedWordEmbeddings;
48+
}
49+
50+
public String getNAEmbedding() {
51+
return naEmbedding;
52+
}
53+
3854
public SimpleMatrix getDocumentEmbedding(Document document) {
3955
if (!conll) {
4056
return new SimpleMatrix(staticWordEmbeddings.getEmbeddingSize(), 1);

src/edu/stanford/nlp/neural/ConvertModels.java

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import org.ejml.simple.SimpleMatrix;
1616

17+
import edu.stanford.nlp.coref.fastneural.FastNeuralCorefModel;
18+
import edu.stanford.nlp.coref.neural.EmbeddingExtractor;
1719
import edu.stanford.nlp.coref.neural.NeuralCorefModel;
1820
import edu.stanford.nlp.io.IOUtils;
1921
import edu.stanford.nlp.parser.dvparser.DVModel;
@@ -34,7 +36,7 @@ public enum Stage {
3436
}
3537

3638
public enum Model {
37-
SENTIMENT, DVPARSER, COREF, EMBEDDING
39+
SENTIMENT, DVPARSER, COREF, EMBEDDING, FASTCOREF
3840
}
3941

4042
/**
@@ -222,6 +224,7 @@ public static void writeEmbedding(Embedding embedding, ObjectOutputStream out)
222224
public static Embedding readEmbedding(ObjectInputStream in)
223225
throws IOException, ClassNotFoundException
224226
{
227+
225228
Function<List<List<Double>>, SimpleMatrix> f = (x) -> toMatrix(x);
226229
Map<String, List<List<Double>>> map = ErasureUtils.uncheckedCast(in.readObject());
227230
Map<String, SimpleMatrix> vectors = transformMap(map, f);
@@ -267,6 +270,48 @@ public static NeuralCorefModel readCoref(ObjectInputStream in)
267270
return model;
268271
}
269272

273+
public static void writeFastCoref(FastNeuralCorefModel model, ObjectOutputStream out)
274+
throws IOException
275+
{
276+
Function<SimpleMatrix, List<List<Double>>> f = (SimpleMatrix x) -> fromMatrix(x);
277+
278+
EmbeddingExtractor embedding = model.getEmbeddingExtractor();
279+
out.writeObject(embedding.isConll());
280+
Embedding staticEmbedding = embedding.getStaticWordEmbeddings();
281+
if (staticEmbedding == null) {
282+
out.writeObject(false);
283+
} else {
284+
out.writeObject(true);
285+
writeEmbedding(staticEmbedding, out);
286+
}
287+
writeEmbedding(embedding.getTunedWordEmbeddings(), out);
288+
out.writeObject(embedding.getNAEmbedding());
289+
290+
out.writeObject(model.getPairFeatureIds());
291+
out.writeObject(model.getMentionFeatureIds());
292+
out.writeObject(CollectionUtils.transformAsList(model.getAllWeights(), f));
293+
}
294+
295+
public static FastNeuralCorefModel readFastCoref(ObjectInputStream in)
296+
throws IOException, ClassNotFoundException
297+
{
298+
Function<List<List<Double>>, SimpleMatrix> f = (x) -> toMatrix(x);
299+
300+
boolean conll = ErasureUtils.uncheckedCast(in.readObject());
301+
boolean hasStatic = ErasureUtils.uncheckedCast(in.readObject());
302+
Embedding staticEmbedding = (hasStatic) ? readEmbedding(in) : null;
303+
Embedding tunedEmbedding = readEmbedding(in);
304+
String naEmbedding = ErasureUtils.uncheckedCast(in.readObject());
305+
306+
EmbeddingExtractor embedding = new EmbeddingExtractor(conll, staticEmbedding, tunedEmbedding, naEmbedding);
307+
308+
Map<String, Integer> pairFeatures = ErasureUtils.uncheckedCast(in.readObject());
309+
Map<String, Integer> mentionFeatures = ErasureUtils.uncheckedCast(in.readObject());
310+
List<SimpleMatrix> weights = CollectionUtils.transformAsList(ErasureUtils.uncheckedCast(in.readObject()), f);
311+
312+
return new FastNeuralCorefModel(embedding, pairFeatures, mentionFeatures, weights);
313+
}
314+
270315
/**
271316
* This program converts a sentiment model or an RNN parser model
272317
* from EJML v23, used by CoreNLP 3.9.2, to a more recent version of
@@ -314,6 +359,12 @@ public static NeuralCorefModel readCoref(ObjectInputStream in)
314359
* <br>
315360
* <code> java edu.stanford.nlp.neural.ConvertModels -stage NEW -model EMBEDDING -input /scr/nlp/data/coref/models/neural/english/english-embeddings.INT.ser.gz -output /scr/nlp/data/coref/models/neural/english/english-embeddings.e39.ser.gz</code>
316361
* <br>
362+
* There is another coref model which is not currently used in CoreNLP, but might be in the future. To upgrade it, use <code>-model FASTCOREF</code>
363+
* <br>
364+
* <code> java edu.stanford.nlp.neural.ConvertModels -stage OLD -model FASTCOREF -input /scr/nlp/data/coref/models/fastneural/fast-english-model.e38.ser.gz -output /scr/nlp/data/coref/models/fastneural/fast-english-model.INT.ser.gz</code>
365+
* <br>
366+
* <code> java edu.stanford.nlp.neural.ConvertModels -stage NEW -model FASTCOREF -input /scr/nlp/data/coref/models/fastneural/fast-english-model.INT.ser.gz -output /scr/nlp/data/coref/models/fastneural/fast-english-model.e39.ser.gz</code>
367+
* <br>
317368
*
318369
* @author <a href=horatio@gmail.com>John Bauer</a>
319370
*/
@@ -331,7 +382,7 @@ public static void main(String[] args) throws IOException, ClassNotFoundExceptio
331382
try {
332383
modelType = Model.valueOf(props.getProperty("model").toUpperCase());
333384
} catch (IllegalArgumentException | NullPointerException e) {
334-
throw new IllegalArgumentException("Please specify -model, either SENTIMENT, DVPARSER, EMBEDDING, COREF");
385+
throw new IllegalArgumentException("Please specify -model, either SENTIMENT, DVPARSER, EMBEDDING, COREF, FASTCOREF");
335386
}
336387

337388
if (!props.containsKey("input")) {
@@ -399,6 +450,18 @@ public static void main(String[] args) throws IOException, ClassNotFoundExceptio
399450
in.close();
400451
IOUtils.writeObjectToFile(model, outputPath);
401452
}
453+
} else if (modelType == Model.FASTCOREF) {
454+
if (stage == Stage.OLD) {
455+
FastNeuralCorefModel model = ErasureUtils.uncheckedCast(IOUtils.readObjectFromURLOrClasspathOrFileSystem(inputPath));
456+
ObjectOutputStream out = IOUtils.writeStreamFromString(outputPath);
457+
writeFastCoref(model, out);
458+
out.close();
459+
} else {
460+
ObjectInputStream in = IOUtils.readStreamFromString(inputPath);
461+
FastNeuralCorefModel model = readFastCoref(in);
462+
in.close();
463+
IOUtils.writeObjectToFile(model, outputPath);
464+
}
402465
} else {
403466
throw new IllegalArgumentException("Unknown model type " + modelType);
404467
}

0 commit comments

Comments
 (0)