@@ -142,7 +142,7 @@ def test_similarity_unseen_docs(self):
142142 model .build_vocab (corpus )
143143 self .assertTrue (model .docvecs .similarity_unseen_docs (model , rome_str , rome_str ) > model .docvecs .similarity_unseen_docs (model , rome_str , car_str ))
144144
145- def model_sanity (self , model ):
145+ def model_sanity (self , model , keep_training = True ):
146146 """Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
147147 fire1 = 0 # doc 0 sydney fires
148148 fire2 = 8 # doc 8 sydney fires
@@ -179,6 +179,12 @@ def model_sanity(self, model):
179179 # fire docs should be closer than fire-tennis
180180 self .assertTrue (model .docvecs .similarity (fire1 , fire2 ) > model .docvecs .similarity (fire1 , tennis1 ))
181181
182+ # keep training after save
183+ if keep_training :
184+ model .save (testfile ())
185+ loaded = doc2vec .Doc2Vec .load (testfile ())
186+ loaded .train (sentences )
187+
182188 def test_training (self ):
183189 """Test doc2vec training."""
184190 corpus = DocsLeeCorpus ()
@@ -316,10 +322,10 @@ def test_delete_temporary_training_data(self):
316322 model .delete_temporary_training_data (keep_doctags_vectors = True , keep_inference = True )
317323 self .assertTrue (model .docvecs and hasattr (model .docvecs , 'doctag_syn0' ))
318324 self .assertTrue (hasattr (model , 'syn1' ))
319- self .model_sanity (model )
325+ self .model_sanity (model , keep_training = False )
320326 model = doc2vec .Doc2Vec (list_corpus , dm = 1 , dm_mean = 1 , size = 24 , window = 4 , hs = 0 , negative = 1 , alpha = 0.05 , min_count = 2 , iter = 20 )
321327 model .delete_temporary_training_data (keep_doctags_vectors = True , keep_inference = True )
322- self .model_sanity (model )
328+ self .model_sanity (model , keep_training = False )
323329 self .assertTrue (hasattr (model , 'syn1neg' ))
324330
325331 @log_capture ()
0 commit comments