|
280 | 280 | """ |
281 | 281 |
|
282 | 282 | import logging |
| 283 | +import os |
283 | 284 |
|
284 | 285 | import numpy as np |
285 | 286 | from numpy import ones, vstack, float32 as REAL, sum as np_sum |
286 | 287 | import six |
| 288 | +from collections.abc import Iterable |
287 | 289 |
|
288 | 290 | import gensim.models._fasttext_bin |
289 | 291 |
|
@@ -901,6 +903,19 @@ def train(self, sentences=None, corpus_file=None, total_examples=None, total_wor |
901 | 903 | >>> model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs) |
902 | 904 |
|
903 | 905 | """ |
| 906 | + |
| 907 | + if corpus_file is None and sentences is None: |
| 908 | + raise TypeError("Either corpus_file or sentences must be provided") |
| 909 | + |
| 910 | + if corpus_file is not None and sentences is not None: |
| 911 | + raise TypeError("corpus_file and sentences must not both be provided") |
| 912 | + |
| 913 | + if sentences is None and not os.path.isfile(corpus_file): |
| 914 | + raise TypeError("Parameter corpus_file must be a valid path to a file, got %r instead" % corpus_file) |
| 915 | + |
| 916 | + if sentences is not None and not isinstance(sentences, Iterable): |
| 917 | + raise TypeError("sentences must be an iterable of lists of words, got %r instead" % sentences) |
| 918 | + |
904 | 919 | super(FastText, self).train( |
905 | 920 | sentences=sentences, corpus_file=corpus_file, total_examples=total_examples, total_words=total_words, |
906 | 921 | epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, |
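
Taken together, the new guards make `train()` fail fast with a `TypeError` instead of erroring out deeper in the training loop. A minimal sketch of the resulting behavior, assuming the gensim 3.x-era API used in this file (the toy corpus and hyperparameters are illustrative):

```python
from gensim.models import FastText

# Illustrative toy corpus: an iterable of tokenized sentences.
sentences = [["hello", "world"], ["machine", "learning"]]

model = FastText(size=4, window=3, min_count=1)
model.build_vocab(sentences)

# Passing neither sentences nor corpus_file is now rejected up front:
try:
    model.train(total_examples=model.corpus_count, epochs=model.epochs)
except TypeError as err:
    print(err)  # Either corpus_file or sentences must be provided

# The supported call, as in the docstring above:
model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
```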
@@ -1023,30 +1038,22 @@ def load(cls, *args, **kwargs): |
1023 | 1038 | """ |
1024 | 1039 | try: |
1025 | 1040 | model = super(FastText, cls).load(*args, **kwargs) |
1026 | | - if hasattr(model.wv, 'hash2index'): |
1027 | | - gensim.models.keyedvectors._rollback_optimization(model.wv) |
1028 | 1041 |
|
1029 | 1042 | if not hasattr(model.trainables, 'vectors_vocab_lockf') and hasattr(model.wv, 'vectors_vocab'): |
1030 | 1043 | model.trainables.vectors_vocab_lockf = ones(model.wv.vectors_vocab.shape, dtype=REAL) |
1031 | 1044 | if not hasattr(model.trainables, 'vectors_ngrams_lockf') and hasattr(model.wv, 'vectors_ngrams'): |
1032 | 1045 | model.trainables.vectors_ngrams_lockf = ones(model.wv.vectors_ngrams.shape, dtype=REAL) |
1033 | 1046 |
|
1034 | | - if not hasattr(model.wv, 'compatible_hash'): |
1035 | | - logger.warning( |
1036 | | - "This older model was trained with a buggy hash function. " |
1037 | | - "The model will continue to work, but consider training it " |
1038 | | - "from scratch." |
1039 | | - ) |
1040 | | - model.wv.compatible_hash = False |
1041 | | - |
1042 | 1047 | if not hasattr(model.wv, 'bucket'): |
1043 | 1048 | model.wv.bucket = model.trainables.bucket |
1044 | | - |
1045 | | - return model |
1046 | 1049 | except AttributeError: |
1047 | 1050 | logger.info('Model saved using code from earlier Gensim Version. Re-loading old model in a compatible way.') |
1048 | 1051 | from gensim.models.deprecated.fasttext import load_old_fasttext |
1049 | | - return load_old_fasttext(*args, **kwargs) |
| 1052 | + model = load_old_fasttext(*args, **kwargs) |
| 1053 | + |
| 1054 | + gensim.models.keyedvectors._try_upgrade(model.wv) |
| 1055 | + |
| 1056 | + return model |
1050 | 1057 |
|
1051 | 1058 | @deprecated("Method will be removed in 4.0.0, use self.wv.accuracy() instead") |
1052 | 1059 | def accuracy(self, questions, restrict_vocab=30000, most_similar=None, case_insensitive=True): |
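
With the restructuring above, the modern loader and the legacy `load_old_fasttext` path converge on a single exit point, and `gensim.models.keyedvectors._try_upgrade` runs on `model.wv` in either case. A hedged usage sketch (the file path is hypothetical):

```python
from gensim.models import FastText

# Hypothetical path to a model saved by an older gensim release.
model = FastText.load("/tmp/old_fasttext.model")

# Whichever branch loaded it, the compatibility fixups in load() have run:
# model.wv was passed through _try_upgrade, and model.wv.bucket exists.
print(model.wv.bucket)
```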
|