5
5
import pytest
6
6
from numpy .testing import assert_array_equal
7
7
from scipy .linalg import block_diag
8
- from scipy .sparse import csr_matrix
9
8
from scipy .special import psi
10
9
11
10
from sklearn .decomposition import LatentDirichletAllocation
20
19
assert_array_almost_equal ,
21
20
if_safe_multiprocessing_with_blas ,
22
21
)
22
+ from sklearn .utils .fixes import CSR_CONTAINERS
23
23
24
24
25
- def _build_sparse_mtx ( ):
25
+ def _build_sparse_array ( csr_container ):
26
26
# Create 3 topics and each topic has 3 distinct words.
27
27
# (Each word only belongs to a single topic.)
28
28
n_components = 3
29
29
block = np .full ((3 , 3 ), n_components , dtype = int )
30
30
blocks = [block ] * n_components
31
31
X = block_diag (* blocks )
32
- X = csr_matrix (X )
32
+ X = csr_container (X )
33
33
return (n_components , X )
34
34
35
35
36
- def test_lda_default_prior_params ():
36
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
37
+ def test_lda_default_prior_params (csr_container ):
37
38
# default prior parameter should be `1 / topics`
38
39
# and verbose params should not affect result
39
- n_components , X = _build_sparse_mtx ( )
40
+ n_components , X = _build_sparse_array ( csr_container )
40
41
prior = 1.0 / n_components
41
42
lda_1 = LatentDirichletAllocation (
42
43
n_components = n_components ,
@@ -50,10 +51,11 @@ def test_lda_default_prior_params():
50
51
assert_almost_equal (topic_distr_1 , topic_distr_2 )
51
52
52
53
53
- def test_lda_fit_batch ():
54
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
55
+ def test_lda_fit_batch (csr_container ):
54
56
# Test LDA batch learning_offset (`fit` method with 'batch' learning)
55
57
rng = np .random .RandomState (0 )
56
- n_components , X = _build_sparse_mtx ( )
58
+ n_components , X = _build_sparse_array ( csr_container )
57
59
lda = LatentDirichletAllocation (
58
60
n_components = n_components ,
59
61
evaluate_every = 1 ,
@@ -69,10 +71,11 @@ def test_lda_fit_batch():
69
71
assert tuple (sorted (top_idx )) in correct_idx_grps
70
72
71
73
72
- def test_lda_fit_online ():
74
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
75
+ def test_lda_fit_online (csr_container ):
73
76
# Test LDA online learning (`fit` method with 'online' learning)
74
77
rng = np .random .RandomState (0 )
75
- n_components , X = _build_sparse_mtx ( )
78
+ n_components , X = _build_sparse_array ( csr_container )
76
79
lda = LatentDirichletAllocation (
77
80
n_components = n_components ,
78
81
learning_offset = 10.0 ,
@@ -89,11 +92,12 @@ def test_lda_fit_online():
89
92
assert tuple (sorted (top_idx )) in correct_idx_grps
90
93
91
94
92
- def test_lda_partial_fit ():
95
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
96
+ def test_lda_partial_fit (csr_container ):
93
97
# Test LDA online learning (`partial_fit` method)
94
98
# (same as test_lda_batch)
95
99
rng = np .random .RandomState (0 )
96
- n_components , X = _build_sparse_mtx ( )
100
+ n_components , X = _build_sparse_array ( csr_container )
97
101
lda = LatentDirichletAllocation (
98
102
n_components = n_components ,
99
103
learning_offset = 10.0 ,
@@ -109,10 +113,11 @@ def test_lda_partial_fit():
109
113
assert tuple (sorted (top_idx )) in correct_idx_grps
110
114
111
115
112
- def test_lda_dense_input ():
116
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
117
+ def test_lda_dense_input (csr_container ):
113
118
# Test LDA with dense input.
114
119
rng = np .random .RandomState (0 )
115
- n_components , X = _build_sparse_mtx ( )
120
+ n_components , X = _build_sparse_array ( csr_container )
116
121
lda = LatentDirichletAllocation (
117
122
n_components = n_components , learning_method = "batch" , random_state = rng
118
123
)
@@ -175,9 +180,10 @@ def test_lda_no_component_error():
175
180
176
181
177
182
@if_safe_multiprocessing_with_blas
183
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
178
184
@pytest .mark .parametrize ("method" , ("online" , "batch" ))
179
- def test_lda_multi_jobs (method ):
180
- n_components , X = _build_sparse_mtx ( )
185
+ def test_lda_multi_jobs (method , csr_container ):
186
+ n_components , X = _build_sparse_array ( csr_container )
181
187
# Test LDA batch training with multi CPU
182
188
rng = np .random .RandomState (0 )
183
189
lda = LatentDirichletAllocation (
@@ -196,10 +202,11 @@ def test_lda_multi_jobs(method):
196
202
197
203
198
204
@if_safe_multiprocessing_with_blas
199
- def test_lda_partial_fit_multi_jobs ():
205
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
206
+ def test_lda_partial_fit_multi_jobs (csr_container ):
200
207
# Test LDA online training with multi CPU
201
208
rng = np .random .RandomState (0 )
202
- n_components , X = _build_sparse_mtx ( )
209
+ n_components , X = _build_sparse_array ( csr_container )
203
210
lda = LatentDirichletAllocation (
204
211
n_components = n_components ,
205
212
n_jobs = 2 ,
@@ -240,10 +247,11 @@ def test_lda_preplexity_mismatch():
240
247
241
248
242
249
@pytest .mark .parametrize ("method" , ("online" , "batch" ))
243
- def test_lda_perplexity (method ):
250
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
251
+ def test_lda_perplexity (method , csr_container ):
244
252
# Test LDA perplexity for batch training
245
253
# perplexity should be lower after each iteration
246
- n_components , X = _build_sparse_mtx ( )
254
+ n_components , X = _build_sparse_array ( csr_container )
247
255
lda_1 = LatentDirichletAllocation (
248
256
n_components = n_components ,
249
257
max_iter = 1 ,
@@ -271,10 +279,11 @@ def test_lda_perplexity(method):
271
279
272
280
273
281
@pytest .mark .parametrize ("method" , ("online" , "batch" ))
274
- def test_lda_score (method ):
282
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
283
+ def test_lda_score (method , csr_container ):
275
284
# Test LDA score for batch training
276
285
# score should be higher after each iteration
277
- n_components , X = _build_sparse_mtx ( )
286
+ n_components , X = _build_sparse_array ( csr_container )
278
287
lda_1 = LatentDirichletAllocation (
279
288
n_components = n_components ,
280
289
max_iter = 1 ,
@@ -297,10 +306,11 @@ def test_lda_score(method):
297
306
assert score_2 >= score_1
298
307
299
308
300
- def test_perplexity_input_format ():
309
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
310
+ def test_perplexity_input_format (csr_container ):
301
311
# Test LDA perplexity for sparse and dense input
302
312
# score should be the same for both dense and sparse input
303
- n_components , X = _build_sparse_mtx ( )
313
+ n_components , X = _build_sparse_array ( csr_container )
304
314
lda = LatentDirichletAllocation (
305
315
n_components = n_components ,
306
316
max_iter = 1 ,
@@ -314,9 +324,10 @@ def test_perplexity_input_format():
314
324
assert_almost_equal (perp_1 , perp_2 )
315
325
316
326
317
- def test_lda_score_perplexity ():
327
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
328
+ def test_lda_score_perplexity (csr_container ):
318
329
# Test the relationship between LDA score and perplexity
319
- n_components , X = _build_sparse_mtx ( )
330
+ n_components , X = _build_sparse_array ( csr_container )
320
331
lda = LatentDirichletAllocation (
321
332
n_components = n_components , max_iter = 10 , random_state = 0
322
333
)
@@ -328,10 +339,11 @@ def test_lda_score_perplexity():
328
339
assert_almost_equal (perplexity_1 , perplexity_2 )
329
340
330
341
331
- def test_lda_fit_perplexity ():
342
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
343
+ def test_lda_fit_perplexity (csr_container ):
332
344
# Test that the perplexity computed during fit is consistent with what is
333
345
# returned by the perplexity method
334
- n_components , X = _build_sparse_mtx ( )
346
+ n_components , X = _build_sparse_array ( csr_container )
335
347
lda = LatentDirichletAllocation (
336
348
n_components = n_components ,
337
349
max_iter = 1 ,
@@ -350,10 +362,11 @@ def test_lda_fit_perplexity():
350
362
assert_almost_equal (perplexity1 , perplexity2 )
351
363
352
364
353
- def test_lda_empty_docs ():
365
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
366
+ def test_lda_empty_docs (csr_container ):
354
367
"""Test LDA on empty document (all-zero rows)."""
355
368
Z = np .zeros ((5 , 4 ))
356
- for X in [Z , csr_matrix (Z )]:
369
+ for X in [Z , csr_container (Z )]:
357
370
lda = LatentDirichletAllocation (max_iter = 750 ).fit (X )
358
371
assert_almost_equal (
359
372
lda .components_ .sum (axis = 0 ), np .ones (lda .components_ .shape [1 ])
@@ -376,8 +389,10 @@ def test_dirichlet_expectation():
376
389
)
377
390
378
391
379
- def check_verbosity (verbose , evaluate_every , expected_lines , expected_perplexities ):
380
- n_components , X = _build_sparse_mtx ()
392
+ def check_verbosity (
393
+ verbose , evaluate_every , expected_lines , expected_perplexities , csr_container
394
+ ):
395
+ n_components , X = _build_sparse_array (csr_container )
381
396
lda = LatentDirichletAllocation (
382
397
n_components = n_components ,
383
398
max_iter = 3 ,
@@ -409,13 +424,19 @@ def check_verbosity(verbose, evaluate_every, expected_lines, expected_perplexiti
409
424
(True , 2 , 3 , 1 ),
410
425
],
411
426
)
412
- def test_verbosity (verbose , evaluate_every , expected_lines , expected_perplexities ):
413
- check_verbosity (verbose , evaluate_every , expected_lines , expected_perplexities )
427
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
428
+ def test_verbosity (
429
+ verbose , evaluate_every , expected_lines , expected_perplexities , csr_container
430
+ ):
431
+ check_verbosity (
432
+ verbose , evaluate_every , expected_lines , expected_perplexities , csr_container
433
+ )
414
434
415
435
416
- def test_lda_feature_names_out ():
436
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
437
+ def test_lda_feature_names_out (csr_container ):
417
438
"""Check feature names out for LatentDirichletAllocation."""
418
- n_components , X = _build_sparse_mtx ( )
439
+ n_components , X = _build_sparse_array ( csr_container )
419
440
lda = LatentDirichletAllocation (n_components = n_components ).fit (X )
420
441
421
442
names = lda .get_feature_names_out ()
0 commit comments