@@ -19,11 +19,26 @@ class LSA(TransformPrimitive):
19
19
Given a list of strings, transforms those strings using tf-idf and singular
20
20
value decomposition to go from a sparse matrix to a compact matrix with two
21
21
values for each string. These values represent the Latent Semantic Analysis
22
- of each string. These values will represent their context with respect to
23
- (nltk's gutenberg corpus.)[https://www.nltk.org/book/ch02.html#gutenberg-corpus]
22
+ of each string. By default these values will represent their context with respect to
23
+ `nltk's gutenberg corpus <https://www.nltk.org/book/ch02.html#gutenberg-corpus>`_.
24
+ Users can optionally pass in a custom corpus when initializing the primitive
25
+ by specifying the corpus values in a list with the corpus parameter.
24
26
25
27
If a string is missing, return `NaN`.
26
28
29
+ Note: If a small custom corpus is used, the output of the primitive may vary
30
+ depending on the platform being used (Linux, MacOS, Windows). This
31
+ is especially true when using the default "randomized" algorithm for the
32
+ TruncatedSVD component.
33
+
34
+ Args:
35
+ random_seed (int, optional): The random seed value to use for the call to TruncatedSVD.
36
+ Will default to 0 if not specified.
37
+ corpus (list[str], optional): A list of strings to use as a custom corpus. Will
38
+ default to the NLTK Gutenberg corpus if not specified.
39
+ algorithm (str, optional): The algorithm to use for the call to TruncatedSVD. Should be either
40
+ "randomized" or "arpack". Will default to "randomized" if not specified.
41
+
27
42
Examples:
28
43
>>> lsa = LSA()
29
44
>>> x = ["he helped her walk,", "me me me eat food", "the sentence doth long"]
@@ -32,8 +47,8 @@ class LSA(TransformPrimitive):
32
47
>>> res
33
48
[[0.01, 0.01, 0.01], [0.0, 0.0, 0.01]]
34
49
35
- Now, if we change the values of the input corpus , to something that better resembles
36
- the given text , the same given input text will result in a different, more discerning,
50
+ Now, if we change the values of the input text to something that better resembles
51
+ the given corpus, the same given input text will result in a different, more discerning,
37
52
output. Also, NaN values are handled, as well as strings without words.
38
53
39
54
>>> lsa = LSA()
@@ -43,25 +58,48 @@ class LSA(TransformPrimitive):
43
58
>>> res
44
59
[[0.02, 0.0, nan, 0.0], [0.02, 0.0, nan, 0.0]]
45
60
61
+ Users can optionally also pass in a custom corpus and specify the algorithm to use
62
+ for the TruncatedSVD component used by the primitive.
63
+
64
+ >>> custom_corpus = ["dogs ate food", "she ate pineapple", "hello"]
65
+ >>> lsa = LSA(corpus=custom_corpus, algorithm="arpack")
66
+ >>> x = ["The dogs ate food.",
67
+ ... "She ate a pineapple",
68
+ ... "Consume Electrolytes, he told me.",
69
+ ... "Hello",]
70
+ >>> res = lsa(x).tolist()
71
+ >>> for i in range(len(res)): res[i] = [abs(round(x, 2)) for x in res[i]]
72
+ >>> res
73
+ [[0.68, 0.78, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
46
74
"""
47
75
48
76
name = "lsa"
49
77
input_types = [ColumnSchema (logical_type = NaturalLanguage )]
50
78
return_type = ColumnSchema (logical_type = Double , semantic_tags = {"numeric" })
51
79
default_value = 0
52
80
53
- def __init__ (self , random_seed = 0 ):
54
- # TODO: allow user to use own corpus
81
+ def __init__ (self , random_seed = 0 , corpus = None , algorithm = None ):
55
82
self .number_output_features = 2
56
83
self .n = 2
57
84
self .trainer = None
58
85
self .random_seed = random_seed
86
+ self .corpus = corpus
87
+ self .algorithm = algorithm or "randomized"
88
+ if self .algorithm not in ["randomized" , "arpack" ]:
89
+ raise ValueError (
90
+ "TruncatedSVD algorithm must be either 'randomized' or 'arpack'"
91
+ )
59
92
60
93
def _create_trainer(self):
    """Build the tf-idf + TruncatedSVD pipeline and fit it on the corpus.

    The fitted pipeline is stored on ``self.trainer``. The training
    documents are ``self.corpus`` when one was supplied; otherwise every
    sentence of NLTK's Gutenberg corpus is used.
    """
    documents = self.corpus
    if documents is None:
        # Gutenberg sentences arrive as token sequences; join each one back
        # into a single space-separated string for the vectorizer.
        documents = [
            " ".join(tokens) for tokens in nltk.corpus.gutenberg.sents()
        ]

    reducer = TruncatedSVD(
        random_state=self.random_seed,
        algorithm=self.algorithm,
    )

    # Assign before fitting so ``self.trainer`` is set even if fit raises,
    # matching the original statement order.
    self.trainer = make_pipeline(TfidfVectorizer(), reducer)
    self.trainer.fit(documents)
65
103
66
104
def get_function (self ):
67
105
if self .trainer is None :
0 commit comments