
Commit db5b133

Allow for custom corpus in LSA primitive (#148)
* allow for custom corpus in LSA primitive
* update release notes
* update new test answers
* update test
* more test cleanup
* update _create_trainer
* use arpack instead of randomized
* update doctest
* remove doctest
* update algo again
* add back doctest
* update docstring, algorithm and tests
* lint fix
* Update nlp_primitives/lsa.py
  Co-authored-by: Gaurav Sheni <gvsheni@gmail.com>
* catch bad SVD algorithm input
* update docstring to include args
* fix doc link

Co-authored-by: Gaurav Sheni <gvsheni@gmail.com>
1 parent 31863e5 commit db5b133
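
For context, here is a minimal usage sketch of the options this commit adds, mirroring the new doctest in the lsa.py diff below. The top-level import path is an assumption; the corpus values are illustrative.

    # Sketch only: assumes LSA is exposed at the package top level.
    from nlp_primitives import LSA

    custom_corpus = ["dogs ate food", "she ate pineapple", "hello"]

    # New in this commit: pass a custom corpus and choose the TruncatedSVD
    # algorithm ("randomized", the default, or "arpack"); any other value
    # raises a ValueError.
    lsa = LSA(corpus=custom_corpus, algorithm="arpack")

    # Two rows come back (one per LSA component), each with one value per input string.
    res = lsa(["The dogs ate food.", "Hello"])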

3 files changed: 89 additions & 10 deletions

docs/source/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@ Changelog
 Future Release
 ==============
     * Enhancements
+        * Allow users to optionally pass in a custom corpus to use with the LSA primitive (:pr:`148`)
     * Fixes
         * Fix bug in ``CountString`` with null values (:pr:`154`)
         * Fix a bug where nltk data was not included in the package (:pr:`157`)
@@ -27,7 +28,7 @@ v2.6.0 Jun 16, 2022
         * Fixed unit tests workflow test choice logic (:pr:`151`)

     Thanks to the following people for contributing to this release:
-    :user:`gsheni`, :user:`rwedge`
+    :user:`gsheni`, :user:`rwedge`, :user:`thehomebrewnerd`

 v2.5.0 Apr 7, 2022
 ==================

nlp_primitives/lsa.py

Lines changed: 47 additions & 9 deletions
@@ -19,11 +19,26 @@ class LSA(TransformPrimitive):
         Given a list of strings, transforms those strings using tf-idf and singular
         value decomposition to go from a sparse matrix to a compact matrix with two
         values for each string. These values represent the Latent Semantic Analysis
-        of each string. These values will represent their context with respect to
-        (nltk's gutenberg corpus.)[https://www.nltk.org/book/ch02.html#gutenberg-corpus]
+        of each string. By default these values will represent their context with respect to
+        `nltk's gutenberg corpus <https://www.nltk.org/book/ch02.html#gutenberg-corpus>`_.
+        Users can optionally pass in a custom corpus when initializing the primitive
+        by specifying the corpus values in a list with the corpus parameter.

         If a string is missing, return `NaN`.

+        Note: If a small custom corpus is used, the output of the primitive may vary
+        depending on the computer architecture being used (Linux, MacOS, Windows). This
+        is especially true when using the default "randomized" algorithm for the
+        TruncatedSVD component.
+
+    Args:
+        random_seed (int, optional): The random seed value to use for the call to TruncatedSVD.
+            Will default to 0 if not specified.
+        corpus (list[str], optional): A list of strings to use as a custom corpus. Will
+            default to the NLTK Gutenberg corpus if not specified.
+        algorithm (str, optional): The algorithm to use for the call to TruncatedSVD. Should be either
+            "randomized" or "arpack". Will default to "randomized" if not specified.
+
     Examples:
         >>> lsa = LSA()
         >>> x = ["he helped her walk,", "me me me eat food", "the sentence doth long"]
@@ -32,8 +47,8 @@ class LSA(TransformPrimitive):
         >>> res
         [[0.01, 0.01, 0.01], [0.0, 0.0, 0.01]]

-        Now, if we change the values of the input corpus, to something that better resembles
-        the given text, the same given input text will result in a different, more discerning,
+        Now, if we change the values of the input text, to something that better resembles
+        the given corpus, the same given input text will result in a different, more discerning,
         output. Also, NaN values are handled, as well as strings without words.

         >>> lsa = LSA()
@@ -43,25 +58,48 @@ class LSA(TransformPrimitive):
         >>> res
         [[0.02, 0.0, nan, 0.0], [0.02, 0.0, nan, 0.0]]

+        Users can optionally also pass in a custom corpus and specify the algorithm to use
+        for the TruncatedSVD component used by the primitive.
+
+        >>> custom_corpus = ["dogs ate food", "she ate pineapple", "hello"]
+        >>> lsa = LSA(corpus=custom_corpus, algorithm="arpack")
+        >>> x = ["The dogs ate food.",
+        ... "She ate a pineapple",
+        ... "Consume Electrolytes, he told me.",
+        ... "Hello",]
+        >>> res = lsa(x).tolist()
+        >>> for i in range(len(res)): res[i] = [abs(round(x, 2)) for x in res[i]]
+        >>> res
+        [[0.68, 0.78, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
     """

     name = "lsa"
     input_types = [ColumnSchema(logical_type=NaturalLanguage)]
     return_type = ColumnSchema(logical_type=Double, semantic_tags={"numeric"})
     default_value = 0

-    def __init__(self, random_seed=0):
-        # TODO: allow user to use own corpus
+    def __init__(self, random_seed=0, corpus=None, algorithm=None):
         self.number_output_features = 2
         self.n = 2
         self.trainer = None
         self.random_seed = random_seed
+        self.corpus = corpus
+        self.algorithm = algorithm or "randomized"
+        if self.algorithm not in ["randomized", "arpack"]:
+            raise ValueError(
+                "TruncatedSVD algorithm must be either 'randomized' or 'arpack'"
+            )

     def _create_trainer(self):
-        gutenberg = nltk.corpus.gutenberg.sents()
-        svd = TruncatedSVD(random_state=self.random_seed)
+        if self.corpus is None:
+            gutenberg = nltk.corpus.gutenberg.sents()
+            corpus = [" ".join(sent) for sent in gutenberg]
+        else:
+            corpus = self.corpus
+        svd = TruncatedSVD(random_state=self.random_seed, algorithm=self.algorithm)
+
         self.trainer = make_pipeline(TfidfVectorizer(), svd)
-        self.trainer.fit([" ".join(sent) for sent in gutenberg])
+        self.trainer.fit(corpus)

     def get_function(self):
         if self.trainer is None:
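
The trainer that _create_trainer assembles above is a scikit-learn pipeline: a TfidfVectorizer feeding a two-component TruncatedSVD, fit on either the Gutenberg sentences or the user-supplied corpus. Below is a minimal standalone sketch of that pipeline (illustrative corpus values, not the package's code); "arpack" is the alternative the diff allows to the default "randomized" solver, which the new docstring note flags as a source of platform-dependent output on small corpora.

    # Minimal standalone sketch of the equivalent scikit-learn pipeline.
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import make_pipeline

    corpus = ["dogs ate food", "she ate pineapple", "hello"]  # stand-in for a custom corpus
    svd = TruncatedSVD(n_components=2, algorithm="arpack", random_state=0)
    trainer = make_pipeline(TfidfVectorizer(), svd)
    trainer.fit(corpus)

    # Each input string is reduced to two LSA component values.
    print(trainer.transform(["The dogs ate food.", "Hello"]))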

nlp_primitives/tests/test_lsa.py

Lines changed: 40 additions & 0 deletions
@@ -1,5 +1,7 @@
+import nltk
 import numpy as np
 import pandas as pd
+import pytest

 from ..lsa import LSA
 from .test_utils import PrimitiveT, find_applicable_primitives, valid_dfs
@@ -42,6 +44,39 @@ def test_strings(self):
             decimal=2,
         )

+    def test_strings_custom_corpus(self):
+        x = pd.Series(
+            [
+                "The dogs ate food.",
+                "She ate a pineapple",
+                "Consume Electrolytes, he told me.",
+                "Hello",
+            ]
+        )
+        # Create a new corpus using only the first 10000 elements from Gutenberg
+        gutenberg = nltk.corpus.gutenberg.sents()
+        corpus = [" ".join(sent) for sent in gutenberg]
+        corpus = corpus[:10000]
+        primitive_func = self.primitive(corpus=corpus).get_function()
+
+        answers = pd.Series(
+            [
+                [0.03858566832087156, 0.04979961879358504, 0.013042488281432613, 0.0],
+                [
+                    -0.0010495388842080527,
+                    -0.0011128696986250912,
+                    0.001556757056617563,
+                    0.0,
+                ],
+            ]
+        )
+        results = primitive_func(x)
+        np.testing.assert_array_almost_equal(
+            np.concatenate(([np.array(answers[0])], [np.array(answers[1])]), axis=0),
+            np.concatenate(([np.array(results[0])], [np.array(results[1])]), axis=0),
+            decimal=2,
+        )
+
     def test_nan(self):
         x = pd.Series([np.nan, "#;.<", "This IS a STRING."])
         primitive_func = self.primitive().get_function()
@@ -69,3 +104,8 @@ def test_with_featuretools(self, es):
         valid_dfs(
             es, aggregation, transform, self.primitive.name.upper(), multi_output=True
         )
+
+    def test_bad_algorithm_input_value(self):
+        err_message = "TruncatedSVD algorithm must be either 'randomized' or 'arpack'"
+        with pytest.raises(ValueError, match=err_message):
+            LSA(algorithm="bad_algo")
