Commit 5a2520e

Merge pull request #23 from huggingface/training

Training

2 parents 5d9ef6b + 4b858d6
93 files changed: +13894 -246 lines (large commit; only a subset of the changed files is shown below)

.gitignore

Lines changed: 6 additions & 3 deletions
@@ -14,6 +14,9 @@ __pycache__/
 /include/
 /lib/
 /pip-selfcheck.json
-neuralcoref/data/*
-neuralcoref/train/*
-.cache
+/runs/*
+test_corefs.txt
+test_mentions.txt
+.cache
+/.vscode/*
+/.vscode

neuralcoref/algorithm.py

Lines changed: 65 additions & 34 deletions
@@ -4,13 +4,13 @@
 from __future__ import unicode_literals
 from __future__ import print_function
 
-from pprint import pprint
-
+import sys
 import os
 import spacy
 import numpy as np
 
-from neuralcoref.data import Data, MENTION_TYPE, NO_COREF_LIST
+from neuralcoref.compat import unicode_
+from neuralcoref.document import Document, MENTION_TYPE, NO_COREF_LIST
 
 PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
 
@@ -22,24 +22,28 @@
 #######################
 ###### CLASSES ########
 
-class Model:
+class Model(object):
     '''
     Coreference neural model
     '''
     def __init__(self, model_path):
        weights, biases = [], []
        for file in sorted(os.listdir(model_path)):
            if file.startswith("single_mention_weights"):
-                weights.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                weights.append(w)
            if file.startswith("single_mention_bias"):
-                biases.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                biases.append(w)
        self.single_mention_model = list(zip(weights, biases))
        weights, biases = [], []
        for file in sorted(os.listdir(model_path)):
            if file.startswith("pair_mentions_weights"):
-                weights.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                weights.append(w)
            if file.startswith("pair_mentions_bias"):
-                biases.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                biases.append(w)
        self.pair_mentions_model = list(zip(weights, biases))
 
     def _score(self, features, layers):
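(Annotation: the (weights, bias) pairs assembled in __init__ above are consumed by the _score helper, whose definition starts at the end of this hunk and whose tail appears in the next one. A minimal sketch of that computation, assuming each stored pair acts as an affine layer followed by the ReLU visible in the next hunk; the standalone name and exact layer handling here are illustrative, not taken from the full file:)

import numpy as np

def score(features, layers):
    # features: column vector of shape (n, 1)
    # layers: list of (weights, bias) array pairs, as built by Model.__init__
    for weights, bias in layers:
        features = weights.dot(features) + bias  # affine layer
        features = np.maximum(features, 0)       # ReLU
    return np.sum(features)                      # scalar score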
@@ -49,8 +53,8 @@ def _score(self, features, layers):
         features = np.maximum(features, 0) # ReLU
         return np.sum(features)
 
-    def get_single_mention_score(self, mention_embedding, anaphoricity_features):
-        first_layer_input = np.concatenate([mention_embedding,
+    def get_single_mention_score(self, mention, anaphoricity_features):
+        first_layer_input = np.concatenate([mention.embedding,
                                             anaphoricity_features], axis=0)[:, np.newaxis]
         return self._score(first_layer_input, self.single_mention_model)
 
@@ -61,32 +65,31 @@ def get_pair_mentions_score(self, antecedent, mention, pair_features):
         return self._score(first_layer_input, self.pair_mentions_model)
 
 
-class Coref:
+class Coref(object):
     '''
     Main coreference resolution algorithm
     '''
-    def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None, use_no_coref_list=True, debug=False):
+    def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None,
+                 use_no_coref_list=True, debug=False):
         self.greedyness = greedyness
         self.max_dist = max_dist
         self.max_dist_match = max_dist_match
         self.debug = debug
-
+        model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
+        trained_embed_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
+        print("Loading neuralcoref model from", model_path)
+        self.coref_model = Model(model_path)
         if nlp is None:
             print("Loading spacy model")
             try:
                 spacy.info('en_core_web_sm')
                 model = 'en_core_web_sm'
             except IOError:
                 print("No spacy 2 model detected, using spacy1 'en' model")
+                spacy.info('en')
                 model = 'en'
             nlp = spacy.load(model)
-
-        model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
-        embed_model_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
-        print("loading model from", model_path)
-        self.data = Data(nlp, model_path=embed_model_path, conll=conll, use_no_coref_list=use_no_coref_list, consider_speakers=conll)
-        self.coref_model = Model(model_path)
-
+        self.data = Document(nlp, conll=conll, use_no_coref_list=use_no_coref_list, trained_embed_path=trained_embed_path)
         self.clusters = {}
         self.mention_to_cluster = []
         self.mentions_single_scores = {}
@@ -129,13 +132,22 @@ def _merge_coreference_clusters(self, ant_idx, mention_idx):
 
         del self.clusters[remove_id]
 
+    def remove_singletons_clusters(self):
+        remove_id = []
+        for key, mentions in self.clusters.items():
+            if len(mentions) == 1:
+                remove_id.append(key)
+                self.mention_to_cluster[key] = None
+        for rem in remove_id:
+            del self.clusters[rem]
+
     def display_clusters(self):
         '''
         Print clusters informations
         '''
         print(self.clusters)
         for key, mentions in self.clusters.items():
-            print("cluster", key, "(", ", ".join(str(self.data[m]) for m in mentions), ")")
+            print("cluster", key, "(", ", ".join(unicode_(self.data[m]) for m in mentions), ")")
 
 ###################################
 ####### MAIN COREF FUNCTIONS ######
@@ -150,11 +162,10 @@ def run_coref_on_mentions(self, mentions):
         for mention_idx, ant_list in self.data.get_candidate_pairs(mentions, self.max_dist, self.max_dist_match):
             mention = self.data[mention_idx]
             feats_, ana_feats = self.data.get_single_mention_features(mention)
-            anaphoricity_score = self.coref_model.get_single_mention_score(mention.embedding, ana_feats)
-            self.mentions_single_scores[mention_idx] = anaphoricity_score
+            single_score = self.coref_model.get_single_mention_score(mention, ana_feats)
+            self.mentions_single_scores[mention_idx] = single_score
             self.mentions_single_features[mention_idx] = {"spansEmbeddings": mention.spans_embeddings_, "wordsEmbeddings": mention.words_embeddings_, "features": feats_}
-
-            best_score = anaphoricity_score - 50 * (self.greedyness - 0.5)
+            best_score = single_score - 50 * (self.greedyness - 0.5)
             for ant_idx in ant_list:
                 antecedent = self.data[ant_idx]
                 feats_, pwf = self.data.get_pair_mentions_features(antecedent, mention)
@@ -164,7 +175,6 @@ def run_coref_on_mentions(self, mentions):
                     "antecedentWordsEmbeddings": antecedent.words_embeddings_,
                     "mentionSpansEmbeddings": mention.spans_embeddings_,
                     "mentionWordsEmbeddings": mention.words_embeddings_ }
-
                 if score > best_score:
                     best_score = score
                     best_ant[mention_idx] = ant_idx
@@ -173,25 +183,29 @@ def run_coref_on_mentions(self, mentions):
                     self._merge_coreference_clusters(best_ant[mention_idx], mention_idx)
         return (n_ant, best_ant)
 
-    def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True):
+    def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True, debug=False):
         ''' Run the coreference model on some utterances
 
         Arg:
             last_utterances_added: run the coreference model over the last utterances added to the data
             follow_chains: follow coreference chains over previous utterances
         '''
+        if debug: print("== run_coref_on_utterances == start")
         self._prepare_clusters()
+        if debug: self.display_clusters()
         mentions = list(self.data.get_candidate_mentions(last_utterances_added=last_utterances_added))
         n_ant, antecedents = self.run_coref_on_mentions(mentions)
         mentions = antecedents.values()
-        if follow_chains and n_ant > 0:
+        if follow_chains and last_utterances_added and n_ant > 0:
             i = 0
             while i < MAX_FOLLOW_UP:
                 i += 1
                 n_ant, antecedents = self.run_coref_on_mentions(mentions)
                 mentions = antecedents.values()
                 if n_ant == 0:
                     break
+        if debug: self.display_clusters()
+        if debug: print("== run_coref_on_utterances == end")
 
     def one_shot_coref(self, utterances, utterances_speakers_id=None, context=None,
                        context_speakers_id=None, speakers_names=None):
@@ -236,7 +250,7 @@ def continuous_coref(self, utterances, utterances_speakers_id=None, speakers_nam
 
     def get_utterances(self, last_utterances_added=True):
         ''' Retrieve the list of parsed uterrances'''
-        if last_utterances_added:
+        if last_utterances_added and len(self.data.last_utterances_loaded):
             return [self.data.utterances[idx] for idx in self.data.last_utterances_loaded]
         else:
             return self.data.utterances
@@ -272,9 +286,10 @@ def get_scores(self):
         return {"single_scores": self.mentions_single_scores,
                 "pair_scores": self.mentions_pairs_scores}
 
-    def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
+    def get_clusters(self, remove_singletons=False, use_no_coref_list=False):
         ''' Retrieve cleaned clusters'''
         clusters = self.clusters
+        mention_to_cluster = self.mention_to_cluster
         remove_id = []
         if use_no_coref_list:
             for key, mentions in clusters.items():
@@ -289,7 +304,7 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
             for key, mentions in clusters.items():
                 if self.data.mentions[key].lower_ in NO_COREF_LIST:
                     remove_id.append(key)
-                    self.mention_to_cluster[key] = None
+                    mention_to_cluster[key] = None
                     if mentions:
                         added[mentions[0]] = mentions
             for rem in remove_id:
@@ -301,11 +316,11 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
             for key, mentions in clusters.items():
                 if len(mentions) == 1:
                     remove_id.append(key)
-                    self.mention_to_cluster[key] = None
+                    mention_to_cluster[key] = None
             for rem in remove_id:
                 del clusters[rem]
 
-        return clusters
+        return clusters, mention_to_cluster
 
     def get_most_representative(self, last_utterances_added=True, use_no_coref_list=True):
         '''
@@ -314,7 +329,7 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list=
         Return:
             Dictionnary of {original_mention: most_representative_resolved_mention, ...}
         '''
-        clusters = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
+        clusters, _ = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
         coreferences = {}
         for key in self.data.get_candidate_mentions(last_utterances_added=last_utterances_added):
             if self.mention_to_cluster[key] is None:
@@ -333,3 +348,19 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list=
                     representative = mention
 
         return coreferences
+
+if __name__ == '__main__':
+    coref = Coref(use_no_coref_list=False)
+    if len(sys.argv) > 1:
+        sent = sys.argv[1]
+        coref.one_shot_coref(sent)
+    else:
+        coref.one_shot_coref(u"Yes, I noticed that many friends, around me received it. It seems that almost everyone received this SMS.")#u"My sister has a dog. She loves him.")
+    mentions = coref.get_mentions()
+    print(mentions)
+
+    utterances = coref.get_utterances()
+    print(utterances)
+
+    resolved_utterance_text = coref.get_resolved_utterances()
+    print(resolved_utterance_text)
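(Note on the API change above: get_clusters now defaults to remove_singletons=False and use_no_coref_list=False, and returns a (clusters, mention_to_cluster) pair instead of a bare dict, which is why get_most_representative now unpacks two values. A minimal usage sketch, assuming the package and its bundled weights/ directory are installed:)

from neuralcoref.algorithm import Coref

coref = Coref(use_no_coref_list=False)
coref.one_shot_coref(u"My sister has a dog. She loves him.")

# New return shape: a (clusters, mention_to_cluster) pair.
clusters, mention_to_cluster = coref.get_clusters()
print(clusters)

# The old cleaning behaviour is now opt-in:
clusters, _ = coref.get_clusters(remove_singletons=True, use_no_coref_list=True)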

neuralcoref/bld.bat

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+"%PYTHON%" setup.py install --single-version-externally-managed --record=record.txt
+if errorlevel 1 exit 1

neuralcoref/build.sh

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+$PYTHON setup.py install --single-version-externally-managed --record=record.txt # Python command to install the script.

neuralcoref/checkpoints/.gitignore

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore

neuralcoref/compat.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# coding: utf8
+"""Py2/3 compatibility"""
+import sys
+
+is_python2 = int(sys.version[0]) == 2
+is_windows = sys.platform.startswith('win')
+is_linux = sys.platform.startswith('linux')
+is_osx = sys.platform == 'darwin'
+
+if is_python2:
+    bytes_ = str
+    unicode_ = unicode
+    string_types = (str, unicode)
+    chr_ = unichr
+
+    def unicode_to_bytes(s, encoding='utf8', errors='strict'):
+        return s.encode(encoding=encoding, errors=errors)
+
+    def bytes_to_unicode(b, encoding='utf8', errors='strict'):
+        return unicode_(b, encoding=encoding, errors=errors)
+
+else:
+    bytes_ = bytes
+    unicode_ = str
+    string_types = (bytes, str)
+    chr_ = chr
+
+    def unicode_to_bytes(s, encoding='utf8', errors='strict'):
+        return s.encode(encoding=encoding, errors=errors)
+
+    def bytes_to_unicode(b, encoding='utf8', errors='strict'):
+        return b.decode(encoding=encoding, errors=errors)
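(compat.py gives the rest of the package one set of names for the Py2/Py3 text types. A quick sketch of the Python 3 behaviour; on Python 2, unicode_ is the built-in unicode type instead of str:)

from neuralcoref.compat import unicode_, unicode_to_bytes, bytes_to_unicode

assert unicode_to_bytes(u'héllo') == b'h\xc3\xa9llo'  # text -> utf-8 bytes
assert bytes_to_unicode(b'h\xc3\xa9llo') == u'héllo'  # utf-8 bytes -> text
assert unicode_(42) == u'42'                          # unicode_ aliases str on Py3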
