Skip to content

Commit cfb4ac6

Browse files
committed
fixed return logic of outliers_gesd method, improved tests (fixes #83)
1 parent 8b8e831 commit cfb4ac6

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

scikit_posthocs/_outliers.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Union, List
22
import numpy as np
3+
from numpy.typing import ArrayLike
34
from scipy.stats import t
45

56

@@ -112,8 +113,7 @@ def outliers_grubbs(
112113
G = val / np.std(arr, ddof=1)
113114
N = len(arr)
114115
result = G > (N - 1) / np.sqrt(N) * np.sqrt(
115-
(t.ppf(1 - alpha / (2 * N), N - 2) ** 2)
116-
/ (N - 2 + t.ppf(1 - alpha / (2 * N), N - 2) ** 2)
116+
(t.ppf(1 - alpha / (2 * N), N - 2) ** 2) / (N - 2 + t.ppf(1 - alpha / (2 * N), N - 2) ** 2)
117117
)
118118

119119
if hypo:
@@ -209,7 +209,7 @@ def tietjen(x_, k_):
209209

210210

211211
def outliers_gesd(
212-
x: Union[List, np.ndarray],
212+
x: ArrayLike,
213213
outliers: int = 5,
214214
hypo: bool = False,
215215
report: bool = False,
@@ -245,8 +245,8 @@ def outliers_gesd(
245245
Returns
246246
-------
247247
np.ndarray
248-
Returns the filtered array if alternative hypo is True, otherwise an
249-
unfiltered (input) array.
248+
If hypo is True, returns a boolean array where True indicates an outlier.
249+
If hypo is False, returns the filtered array with outliers removed.
250250
251251
Notes
252252
-----
@@ -308,7 +308,7 @@ def outliers_gesd(
308308

309309
# Masked values
310310
lms = ms[-1] if len(ms) > 0 else []
311-
ms.append(lms + np.where(data == data_proc[np.argmax(abs_d)])[0].tolist())
311+
ms.append(lms + [np.where(data == data_proc[np.argmax(abs_d)])[0][0]])
312312

313313
# Remove the observation that maximizes |xi − xmean|
314314
data_proc = np.delete(data_proc, np.argmax(abs_d))
@@ -341,16 +341,12 @@ def outliers_gesd(
341341
# Remove masked values
342342
# for which the test statistic is greater
343343
# than the critical value and return the result
344-
345-
if any(rs > ls):
346-
if hypo:
347-
data[:] = False
344+
if hypo:
345+
data = np.zeros(n, dtype=bool)
346+
if any(rs > ls):
348347
data[ms[np.max(np.where(rs > ls))]] = True
349-
# rearrange data so mask is in same order as incoming data
350-
data = np.vstack((data, np.arange(0, data.shape[0])[argsort_index]))
351-
data = data[0, data.argsort()[1,]]
352-
data = data.astype("bool")
353-
else:
354-
data = np.delete(data, ms[np.max(np.where(rs > ls))])
355-
356-
return data
348+
return data
349+
else:
350+
if any(rs > ls):
351+
return np.delete(data, ms[np.max(np.where(rs > ls))])
352+
return data

tests/test_posthocs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,10 @@ def test_outliers_gesd(self):
370370
]
371371
)
372372
self.assertTrue(isinstance(so.outliers_gesd(x, 5, report=True), np.ndarray))
373-
self.assertTrue(np.all(test_results == correct_results))
374-
self.assertTrue(np.all(test_mask_results == correct_mask))
373+
self.assertTrue(np.array_equal(test_results, correct_results))
374+
self.assertTrue(np.array_equal(test_mask_results, correct_mask))
375+
self.assertTrue(np.array_equal(so.outliers_gesd(correct_results, 5, hypo=False), correct_results))
376+
self.assertTrue(np.array_equal(so.outliers_gesd(correct_results, 5, hypo=True), np.zeros_like(correct_results, dtype=bool)))
375377

376378
# Statistical tests
377379
df = sb.load_dataset("exercise")

0 commit comments

Comments
 (0)