@@ -4,13 +4,13 @@
 from __future__ import unicode_literals
 from __future__ import print_function
 
-from pprint import pprint
-
+import sys
 import os
 import spacy
 import numpy as np
 
-from neuralcoref.data import Data, MENTION_TYPE, NO_COREF_LIST
+from neuralcoref.compat import unicode_
+from neuralcoref.document import Document, MENTION_TYPE, NO_COREF_LIST
 
 PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
 
@@ -22,24 +22,28 @@
 #######################
 ###### CLASSES ########
 
-class Model:
+class Model(object):
     '''
     Coreference neural model
     '''
     def __init__(self, model_path):
         weights, biases = [], []
         for file in sorted(os.listdir(model_path)):
             if file.startswith("single_mention_weights"):
-                weights.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                weights.append(w)
             if file.startswith("single_mention_bias"):
-                biases.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                biases.append(w)
         self.single_mention_model = list(zip(weights, biases))
         weights, biases = [], []
         for file in sorted(os.listdir(model_path)):
             if file.startswith("pair_mentions_weights"):
-                weights.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                weights.append(w)
             if file.startswith("pair_mentions_bias"):
-                biases.append(np.load(os.path.join(model_path, file)))
+                w = np.load(os.path.join(model_path, file))
+                biases.append(w)
         self.pair_mentions_model = list(zip(weights, biases))
 
     def _score(self, features, layers):
@@ -49,8 +53,8 @@ def _score(self, features, layers):
                 features = np.maximum(features, 0)  # ReLU
         return np.sum(features)
 
-    def get_single_mention_score(self, mention_embedding, anaphoricity_features):
-        first_layer_input = np.concatenate([mention_embedding,
+    def get_single_mention_score(self, mention, anaphoricity_features):
+        first_layer_input = np.concatenate([mention.embedding,
                                             anaphoricity_features], axis=0)[:, np.newaxis]
         return self._score(first_layer_input, self.single_mention_model)
 
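The `(weights, bias)` pairs collected in `Model.__init__` drive a small feed-forward scorer; the hunk above keeps only its tail (the ReLU and the final sum). As a minimal standalone sketch of that scoring loop, assuming the usual dense-layer shapes rather than quoting the upstream body verbatim:

import numpy as np

def score(features, layers):
    # Each (weights, bias) pair is one dense layer of the mention scorer.
    for weights, bias in layers:
        features = np.matmul(weights, features) + bias
        if weights.shape[0] > 1:  # hidden layers only; the 1-unit output stays linear
            features = np.maximum(features, 0)  # ReLU, as in the hunk above
    return np.sum(features)  # the last layer has a single unit, so this is its scalar score

# Toy check with random layers: an 8-dim feature column through a 4-unit hidden layer.
layers = [(np.random.randn(4, 8), np.random.randn(4, 1)),
          (np.random.randn(1, 4), np.random.randn(1, 1))]
print(score(np.random.randn(8, 1), layers))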
@@ -61,32 +65,31 @@ def get_pair_mentions_score(self, antecedent, mention, pair_features):
         return self._score(first_layer_input, self.pair_mentions_model)
 
 
-class Coref:
+class Coref(object):
     '''
     Main coreference resolution algorithm
    '''
-    def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None, use_no_coref_list=True, debug=False):
+    def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None,
+                 use_no_coref_list=True, debug=False):
         self.greedyness = greedyness
         self.max_dist = max_dist
         self.max_dist_match = max_dist_match
         self.debug = debug
-
+        model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
+        trained_embed_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
+        print("Loading neuralcoref model from", model_path)
+        self.coref_model = Model(model_path)
         if nlp is None:
             print("Loading spacy model")
             try:
                 spacy.info('en_core_web_sm')
                 model = 'en_core_web_sm'
             except IOError:
                 print("No spacy 2 model detected, using spacy1 'en' model")
+                spacy.info('en')
                 model = 'en'
             nlp = spacy.load(model)
-
-        model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
-        embed_model_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
-        print("loading model from", model_path)
-        self.data = Data(nlp, model_path=embed_model_path, conll=conll, use_no_coref_list=use_no_coref_list, consider_speakers=conll)
-        self.coref_model = Model(model_path)
-
+        self.data = Document(nlp, conll=conll, use_no_coref_list=use_no_coref_list, trained_embed_path=trained_embed_path)
         self.clusters = {}
         self.mention_to_cluster = []
         self.mentions_single_scores = {}
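With the constructor reordered, the neural weights now load before spaCy, so a bad weights path fails fast instead of surfacing after the slow spaCy load. A usage sketch, assuming this file is importable as `neuralcoref.algorithm` (the module path is an assumption from the repo layout):

# Hypothetical usage of the new constructor signature.
from neuralcoref.algorithm import Coref

coref = Coref(greedyness=0.5, max_dist=50, use_no_coref_list=False)
# "Loading neuralcoref model from .../weights/" now prints before any spaCy message.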
@@ -129,13 +132,22 @@ def _merge_coreference_clusters(self, ant_idx, mention_idx):
 
         del self.clusters[remove_id]
 
+    def remove_singletons_clusters(self):
+        remove_id = []
+        for key, mentions in self.clusters.items():
+            if len(mentions) == 1:
+                remove_id.append(key)
+                self.mention_to_cluster[key] = None
+        for rem in remove_id:
+            del self.clusters[rem]
+
     def display_clusters(self):
         '''
         Print clusters information
         '''
         print(self.clusters)
         for key, mentions in self.clusters.items():
-            print("cluster", key, "(", ", ".join(str(self.data[m]) for m in mentions), ")")
+            print("cluster", key, "(", ", ".join(unicode_(self.data[m]) for m in mentions), ")")
 
 ###################################
 ####### MAIN COREF FUNCTIONS ######
@@ -150,11 +162,10 @@ def run_coref_on_mentions(self, mentions):
         for mention_idx, ant_list in self.data.get_candidate_pairs(mentions, self.max_dist, self.max_dist_match):
             mention = self.data[mention_idx]
             feats_, ana_feats = self.data.get_single_mention_features(mention)
-            anaphoricity_score = self.coref_model.get_single_mention_score(mention.embedding, ana_feats)
-            self.mentions_single_scores[mention_idx] = anaphoricity_score
+            single_score = self.coref_model.get_single_mention_score(mention, ana_feats)
+            self.mentions_single_scores[mention_idx] = single_score
             self.mentions_single_features[mention_idx] = {"spansEmbeddings": mention.spans_embeddings_, "wordsEmbeddings": mention.words_embeddings_, "features": feats_}
-
-            best_score = anaphoricity_score - 50 * (self.greedyness - 0.5)
+            best_score = single_score - 50 * (self.greedyness - 0.5)
             for ant_idx in ant_list:
                 antecedent = self.data[ant_idx]
                 feats_, pwf = self.data.get_pair_mentions_features(antecedent, mention)
@@ -164,7 +175,6 @@ def run_coref_on_mentions(self, mentions):
                                                           "antecedentWordsEmbeddings": antecedent.words_embeddings_,
                                                           "mentionSpansEmbeddings": mention.spans_embeddings_,
                                                           "mentionWordsEmbeddings": mention.words_embeddings_}
-
                 if score > best_score:
                     best_score = score
                     best_ant[mention_idx] = ant_idx
@@ -173,25 +183,29 @@ def run_coref_on_mentions(self, mentions):
             self._merge_coreference_clusters(best_ant[mention_idx], mention_idx)
         return (n_ant, best_ant)
 
-    def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True):
+    def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True, debug=False):
         ''' Run the coreference model on some utterances
 
         Arg:
             last_utterances_added: run the coreference model over the last utterances added to the data
             follow_chains: follow coreference chains over previous utterances
         '''
+        if debug: print("== run_coref_on_utterances == start")
         self._prepare_clusters()
+        if debug: self.display_clusters()
         mentions = list(self.data.get_candidate_mentions(last_utterances_added=last_utterances_added))
         n_ant, antecedents = self.run_coref_on_mentions(mentions)
         mentions = antecedents.values()
-        if follow_chains and n_ant > 0:
+        if follow_chains and last_utterances_added and n_ant > 0:
             i = 0
             while i < MAX_FOLLOW_UP:
                 i += 1
                 n_ant, antecedents = self.run_coref_on_mentions(mentions)
                 mentions = antecedents.values()
                 if n_ant == 0:
                     break
+        if debug: self.display_clusters()
+        if debug: print("== run_coref_on_utterances == end")
 
     def one_shot_coref(self, utterances, utterances_speakers_id=None, context=None,
                        context_speakers_id=None, speakers_names=None):
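The new `debug` flag brackets a resolution pass with cluster dumps, which makes the MAX_FOLLOW_UP chain iterations visible. Continuing the usage sketch above (the sentence is just an example input):

# Resolve once, then re-run the pass with tracing enabled.
coref.one_shot_coref(u"My sister has a dog. She loves him.")
coref.run_coref_on_utterances(last_utterances_added=True, debug=True)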
@@ -236,7 +250,7 @@ def continuous_coref(self, utterances, utterances_speakers_id=None, speakers_names=None):
 
     def get_utterances(self, last_utterances_added=True):
         ''' Retrieve the list of parsed utterances'''
-        if last_utterances_added:
+        if last_utterances_added and len(self.data.last_utterances_loaded):
             return [self.data.utterances[idx] for idx in self.data.last_utterances_loaded]
         else:
             return self.data.utterances
@@ -272,9 +286,10 @@ def get_scores(self):
         return {"single_scores": self.mentions_single_scores,
                 "pair_scores": self.mentions_pairs_scores}
 
-    def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
+    def get_clusters(self, remove_singletons=False, use_no_coref_list=False):
         ''' Retrieve cleaned clusters'''
         clusters = self.clusters
+        mention_to_cluster = self.mention_to_cluster
         remove_id = []
         if use_no_coref_list:
             for key, mentions in clusters.items():
@@ -289,7 +304,7 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
             for key, mentions in clusters.items():
                 if self.data.mentions[key].lower_ in NO_COREF_LIST:
                     remove_id.append(key)
-                    self.mention_to_cluster[key] = None
+                    mention_to_cluster[key] = None
                 if mentions:
                     added[mentions[0]] = mentions
             for rem in remove_id:
@@ -301,11 +316,11 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
             for key, mentions in clusters.items():
                 if len(mentions) == 1:
                     remove_id.append(key)
-                    self.mention_to_cluster[key] = None
+                    mention_to_cluster[key] = None
             for rem in remove_id:
                 del clusters[rem]
 
-        return clusters
+        return clusters, mention_to_cluster
 
     def get_most_representative(self, last_utterances_added=True, use_no_coref_list=True):
         '''
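Since `get_clusters` now returns a pair, every caller has to unpack two values, exactly as `get_most_representative` does in the next hunk. Continuing the usage sketch:

# Unpack the new (clusters, mention_to_cluster) return value.
clusters, mention_to_cluster = coref.get_clusters(remove_singletons=True,
                                                  use_no_coref_list=True)
for cluster_id, mention_ids in clusters.items():
    # mention_to_cluster maps each mention index to its cluster id (or None if removed).
    print(cluster_id, [coref.data[m] for m in mention_ids])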
@@ -314,7 +329,7 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list=True):
         Return:
             Dictionary of {original_mention: most_representative_resolved_mention, ...}
         '''
-        clusters = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
+        clusters, _ = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
         coreferences = {}
         for key in self.data.get_candidate_mentions(last_utterances_added=last_utterances_added):
             if self.mention_to_cluster[key] is None:
@@ -333,3 +348,19 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list=True):
                     representative = mention
 
         return coreferences
+
+if __name__ == '__main__':
+    coref = Coref(use_no_coref_list=False)
+    if len(sys.argv) > 1:
+        sent = sys.argv[1]
+        coref.one_shot_coref(sent)
+    else:
+        coref.one_shot_coref(u"Yes, I noticed that many friends, around me received it. It seems that almost everyone received this SMS.")  # u"My sister has a dog. She loves him."
+    mentions = coref.get_mentions()
+    print(mentions)
+
+    utterances = coref.get_utterances()
+    print(utterances)
+
+    resolved_utterance_text = coref.get_resolved_utterances()
+    print(resolved_utterance_text)
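The demo could go one step further and exercise the cluster-level helpers as well; a possible extension, not part of the commit:

coreferences = coref.get_most_representative()
print(coreferences)  # {original_mention: most_representative_resolved_mention, ...}

single_scores = coref.get_scores()["single_scores"]
print(len(single_scores), "mentions scored")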