@@ -918,7 +918,7 @@ def test_property():
918
918
919
919
def test_sample ():
920
920
rng = np .random .RandomState (0 )
921
- rand_data = RandomData (rng , scale = 7 )
921
+ rand_data = RandomData (rng , scale = 7 , n_components = 3 )
922
922
n_features , n_components = rand_data .n_features , rand_data .n_components
923
923
924
924
for covar_type in COVARIANCE_TYPE :
@@ -937,7 +937,8 @@ def test_sample():
937
937
# Just to make sure the class samples correctly
938
938
n_samples = 20000
939
939
X_s , y_s = gmm .sample (n_samples )
940
- for k in range (n_features ):
940
+
941
+ for k in range (n_components ):
941
942
if covar_type == 'full' :
942
943
assert_array_almost_equal (gmm .covariances_ [k ],
943
944
np .cov (X_s [y_s == k ].T ), decimal = 1 )
@@ -954,15 +955,16 @@ def test_sample():
954
955
decimal = 1 )
955
956
956
957
means_s = np .array ([np .mean (X_s [y_s == k ], 0 )
957
- for k in range (n_features )])
958
+ for k in range (n_components )])
958
959
assert_array_almost_equal (gmm .means_ , means_s , decimal = 1 )
959
960
960
- # Check that sizes that are drawn match what is requested
961
- assert_equal (X_s .shape , (n_samples , n_components ))
962
- for sample_size in range (1 , 50 ):
963
- X_s , _ = gmm .sample (sample_size )
964
- assert_equal (X_s .shape , (sample_size , n_components ))
961
+ # Check shapes of sampled data, see
962
+ # https://github.com/scikit-learn/scikit-learn/issues/7701
963
+ assert_equal (X_s .shape , (n_samples , n_features ))
965
964
965
+ for sample_size in range (1 , 100 ):
966
+ X_s , _ = gmm .sample (sample_size )
967
+ assert_equal (X_s .shape , (sample_size , n_features ))
966
968
967
969
968
970
@ignore_warnings (category = ConvergenceWarning )
0 commit comments