Skip to content

Commit 4f0e2ae

Browse files
dkim010tmylk
authored andcommitted
most_similar_cosmul bug fix (#1177)
* most_similar_cosmul bug fix Signed-off-by: Dongwon Kim <dongwon.kim@navercorp.com> * update most_similar_cosmul test scripts bug fix Signed-off-by: Dongwon Kim <dongwon.kim@navercorp.com> * make most_similar_cosmul() work with vectors * add a test case for most_similar_cosmul Signed-off-by: Dongwon Kim <dongwon.kim@navercorp.com>
1 parent a458658 commit 4f0e2ae

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

gensim/models/keyedvectors.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,14 +455,21 @@ def most_similar_cosmul(self, positive=[], negative=[], topn=10):
455455
# allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog'])
456456
positive = [positive]
457457

458+
all_words = set([self.vocab[word].index for word in positive+negative
459+
if not isinstance(word, ndarray) and word in self.vocab])
460+
461+
positive = [
462+
self.word_vec(word, use_norm=True) if isinstance(word, string_types) else word
463+
for word in positive
464+
]
465+
negative = [
466+
self.word_vec(word, use_norm=True) if isinstance(word, string_types) else word
467+
for word in negative
468+
]
458469

459-
positive = [self.word_vec(word, use_norm=True) for word in positive]
460-
negative = [self.word_vec(word, use_norm=True) for word in negative]
461470
if not positive:
462471
raise ValueError("cannot compute similarity with no input")
463472

464-
all_words = set([self.vocab[word].index for word in positive+negative if word in self.vocab])
465-
466473
# equation (4) of Levy & Goldberg "Linguistic Regularities...",
467474
# with distances shifted to [0,1] per footnote (7)
468475
pos_dists = [((1 + dot(self.syn0norm, term)) / 2) for term in positive]

gensim/test/test_fasttext_wrapper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,11 @@ def testMostSimilar(self):
158158
def testMostSimilarCosmul(self):
159159
"""Test most_similar_cosmul for in-vocab and out-of-vocab words"""
160160
# In vocab, sanity check
161-
self.assertEqual(len(self.test_model.most_similar(positive=['the', 'and'], topn=5)), 5)
162-
self.assertEqual(self.test_model.most_similar('the'), self.test_model.most_similar(positive=['the']))
161+
self.assertEqual(len(self.test_model.most_similar_cosmul(positive=['the', 'and'], topn=5)), 5)
162+
self.assertEqual(self.test_model.most_similar_cosmul('the'), self.test_model.most_similar_cosmul(positive=['the']))
163163
# Out of vocab check
164-
self.assertEqual(len(self.test_model.most_similar(['night', 'nights'], topn=5)), 5)
165-
self.assertEqual(self.test_model.most_similar('nights'), self.test_model.most_similar(positive=['nights']))
164+
self.assertEqual(len(self.test_model.most_similar_cosmul(['night', 'nights'], topn=5)), 5)
165+
self.assertEqual(self.test_model.most_similar_cosmul('nights'), self.test_model.most_similar_cosmul(positive=['nights']))
166166

167167
def testLookup(self):
168168
"""Tests word vector lookup for in-vocab and out-of-vocab words"""
@@ -218,4 +218,4 @@ def testHash(self):
218218

219219
if __name__ == '__main__':
220220
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
221-
unittest.main()
221+
unittest.main()

gensim/test/test_word2vec.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,17 @@ def test_cbow_neg(self):
449449
min_count=5, iter=10, workers=2, sample=0)
450450
self.model_sanity(model)
451451

452+
def test_cosmul(self):
453+
model = word2vec.Word2Vec(sentences, size=2, min_count=1, hs=1, negative=0)
454+
sims = model.most_similar_cosmul('graph', topn=10)
455+
# self.assertTrue(sims[0][0] == 'trees', sims) # most similar
456+
457+
# test querying for "most similar" by vector
458+
graph_vector = model.wv.syn0norm[model.wv.vocab['graph'].index]
459+
sims2 = model.most_similar_cosmul(positive=[graph_vector], topn=11)
460+
sims2 = [(w, sim) for w, sim in sims2 if w != 'graph'] # ignore 'graph' itself
461+
self.assertEqual(sims, sims2)
462+
452463
def testTrainingCbow(self):
453464
"""Test CBOW word2vec training."""
454465
# to test training, make the corpus larger by repeating its sentences over and over

0 commit comments

Comments
 (0)