Skip to content

Commit 2b3328e

Browse files
authored
make test much more tolerant (#86)
1 parent 777ca8a commit 2b3328e

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

tests/split/test_userwise_split.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,8 @@ def get_n_right_answer(ratio: float, n: Optional[int]) -> int:
127127
for i in index:
128128
nnz_learn = X_learn[i].nonzero()[1].shape[0]
129129
nnz_predict = X_predict[i].nonzero()[1].shape[0]
130-
if ceil:
131-
# ceil(ratio * tot) = n_test
132-
# ratio * tot > n_test - 1 #exact
133-
# or ration * tot <= n_test
134-
assert ratio > (nnz_predict - 1) / (nnz_learn + nnz_predict)
135-
assert ratio <= (nnz_predict) / (nnz_learn + nnz_predict)
136-
else:
137-
# floor(ratio * tot) = n_test
138-
# ratio * tot >= n_test #exact
139-
# or ratio * tot < (n_test + 1)
140-
assert ratio >= (nnz_predict) / (nnz_learn + nnz_predict)
141-
assert ratio < (nnz_predict + 1) / (nnz_learn + nnz_predict)
130+
assert ratio >= (nnz_predict - 1) / (nnz_learn + nnz_predict)
131+
assert ratio <= (nnz_predict + 1) / (nnz_learn + nnz_predict)
142132

143133

144134
def test_user_level_split_val_fixed_n() -> None:

0 commit comments

Comments
 (0)