Skip to content

Commit a072e56

Browse files
authored
ENH add normalize to LDA.transform (scikit-learn#30097)
1 parent ed20290 commit a072e56

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :class:`~sklearn.decomposition.LatentDirichletAllocation` now has a
2+
``normalize`` parameter in the ``transform`` and ``fit_transform`` methods
3+
to control whether the document topic distribution is normalized.
4+
By `Adrin Jalali`_.

sklearn/decomposition/_lda.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def _unnormalized_transform(self, X):
723723

724724
return doc_topic_distr
725725

726-
def transform(self, X):
726+
def transform(self, X, *, normalize=True):
727727
"""Transform data X according to the fitted model.
728728
729729
.. versionchanged:: 0.18
@@ -734,6 +734,9 @@ def transform(self, X):
734734
X : {array-like, sparse matrix} of shape (n_samples, n_features)
735735
Document word matrix.
736736
737+
normalize : bool, default=True
738+
Whether to normalize the document topic distribution so that each row sums to 1.
739+
737740
Returns
738741
-------
739742
doc_topic_distr : ndarray of shape (n_samples, n_components)
@@ -744,9 +747,35 @@ def transform(self, X):
744747
X, reset_n_features=False, whom="LatentDirichletAllocation.transform"
745748
)
746749
doc_topic_distr = self._unnormalized_transform(X)
747-
doc_topic_distr /= doc_topic_distr.sum(axis=1)[:, np.newaxis]
750+
if normalize:
751+
doc_topic_distr /= doc_topic_distr.sum(axis=1)[:, np.newaxis]
748752
return doc_topic_distr
749753

754+
def fit_transform(self, X, y=None, *, normalize=True):
755+
"""
756+
Fit to data, then transform it.
757+
758+
Fits transformer to `X` and `y` and returns a transformed version of `X`.
759+
760+
Parameters
761+
----------
762+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
763+
Input samples.
764+
765+
y : array-like of shape (n_samples,) or (n_samples, n_outputs), \
766+
default=None
767+
Target values (None for unsupervised transformations).
768+
769+
normalize : bool, default=True
770+
Whether to normalize the document topic distribution returned by `transform` so that each row sums to 1.
771+
772+
Returns
773+
-------
774+
X_new : ndarray of shape (n_samples, n_components)
775+
Transformed array.
776+
"""
777+
return self.fit(X, y).transform(X, normalize=normalize)
778+
750779
def _approx_bound(self, X, doc_topic_distr, sub_sampling):
751780
"""Estimate the variational bound.
752781

sklearn/decomposition/tests/test_online_lda.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_lda_dense_input(csr_container):
132132

133133
def test_lda_transform():
134134
# Test LDA transform.
135-
# Transform result cannot be negative and should be normalized
135+
# Transform result cannot be negative and should be normalized by default
136136
rng = np.random.RandomState(0)
137137
X = rng.randint(5, size=(20, 10))
138138
n_components = 3
@@ -141,6 +141,11 @@ def test_lda_transform():
141141
assert (X_trans > 0.0).any()
142142
assert_array_almost_equal(np.sum(X_trans, axis=1), np.ones(X_trans.shape[0]))
143143

144+
X_trans_unnormalized = lda.transform(X, normalize=False)
145+
assert_array_almost_equal(
146+
X_trans, X_trans_unnormalized / X_trans_unnormalized.sum(axis=1)[:, np.newaxis]
147+
)
148+
144149

145150
@pytest.mark.parametrize("method", ("online", "batch"))
146151
def test_lda_fit_transform(method):

0 commit comments

Comments
 (0)