Skip to content

Commit f550b06

Browse files
authored
[fix] correct skips in design rule checks (test_common.py) caused by fragile whitelist_to_blacklist (#2086)
* Update test_common.py
* Update test_common.py
* black corrections
* fixes
* swap back to dict
* handle separate drive
* better handle ValueError
* return never viable string
* observe failures in n_jobs support
* force all false
* more informative
* check _get_backend
* reset assert
* return functional test into operation
* move try catch to generalization
* formatting'
* missing knn score methods
* add comments
* add informative xfail text
* force full use of fullpath
1 parent c16f4e6 commit f550b06

File tree

1 file changed

+116
-33
lines changed

1 file changed

+116
-33
lines changed

sklearnex/tests/test_common.py

Lines changed: 116 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
import importlib
17+
import importlib.util
1818
import os
1919
import re
2020
import sys
@@ -44,12 +44,60 @@
4444
"svm" + os.sep + "_common.py",
4545
]
4646

47-
_DESIGN_RULE_VIOLATIONS = [
48-
"PCA-fit_transform-call_validate_data", # calls both "fit" and "transform"
49-
"IncrementalEmpiricalCovariance-score-call_validate_data", # must call clone of itself
50-
"SVC(probability=True)-fit-call_validate_data", # SVC fit can use sklearn estimator
51-
"NuSVC(probability=True)-fit-call_validate_data", # NuSVC fit can use sklearn estimator
52-
]
47+
_DESIGN_RULE_VIOLATIONS = {
48+
"PCA-fit_transform-call_validate_data": "calls both 'fit' and 'transform'",
49+
"IncrementalEmpiricalCovariance-score-call_validate_data": "must call clone of itself",
50+
"SVC(probability=True)-fit-call_validate_data": "SVC fit can use sklearn estimator",
51+
"NuSVC(probability=True)-fit-call_validate_data": "NuSVC fit can use sklearn estimator",
52+
"LogisticRegression-score-n_jobs_check": "uses daal4py for cpu in sklearnex",
53+
"LogisticRegression-fit-n_jobs_check": "uses daal4py for cpu in sklearnex",
54+
"LogisticRegression-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
55+
"LogisticRegression-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
56+
"LogisticRegression-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
57+
"KNeighborsClassifier-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
58+
"KNeighborsClassifier-fit-n_jobs_check": "uses daal4py for cpu in onedal",
59+
"KNeighborsClassifier-score-n_jobs_check": "uses daal4py for cpu in onedal",
60+
"KNeighborsClassifier-predict-n_jobs_check": "uses daal4py for cpu in onedal",
61+
"KNeighborsClassifier-predict_proba-n_jobs_check": "uses daal4py for cpu in onedal",
62+
"KNeighborsClassifier-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
63+
"KNeighborsRegressor-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
64+
"KNeighborsRegressor-fit-n_jobs_check": "uses daal4py for cpu in onedal",
65+
"KNeighborsRegressor-score-n_jobs_check": "uses daal4py for cpu in onedal",
66+
"KNeighborsRegressor-predict-n_jobs_check": "uses daal4py for cpu in onedal",
67+
"KNeighborsRegressor-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
68+
"NearestNeighbors-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
69+
"NearestNeighbors-fit-n_jobs_check": "uses daal4py for cpu in onedal",
70+
"NearestNeighbors-radius_neighbors-n_jobs_check": "uses daal4py for cpu in onedal",
71+
"NearestNeighbors-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
72+
"NearestNeighbors-radius_neighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
73+
"LocalOutlierFactor-fit-n_jobs_check": "uses daal4py for cpu in onedal",
74+
"LocalOutlierFactor-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
75+
"LocalOutlierFactor-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
76+
"KNeighborsClassifier(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
77+
"KNeighborsClassifier(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
78+
"KNeighborsClassifier(algorithm='brute')-score-n_jobs_check": "uses daal4py for cpu in onedal",
79+
"KNeighborsClassifier(algorithm='brute')-predict-n_jobs_check": "uses daal4py for cpu in onedal",
80+
"KNeighborsClassifier(algorithm='brute')-predict_proba-n_jobs_check": "uses daal4py for cpu in onedal",
81+
"KNeighborsClassifier(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
82+
"KNeighborsRegressor(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
83+
"KNeighborsRegressor(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
84+
"KNeighborsRegressor(algorithm='brute')-score-n_jobs_check": "uses daal4py for cpu in onedal",
85+
"KNeighborsRegressor(algorithm='brute')-predict-n_jobs_check": "uses daal4py for cpu in onedal",
86+
"KNeighborsRegressor(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
87+
"NearestNeighbors(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
88+
"NearestNeighbors(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
89+
"NearestNeighbors(algorithm='brute')-radius_neighbors-n_jobs_check": "uses daal4py for cpu in onedal",
90+
"NearestNeighbors(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
91+
"NearestNeighbors(algorithm='brute')-radius_neighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
92+
"LocalOutlierFactor(novelty=True)-fit-n_jobs_check": "uses daal4py for cpu in onedal",
93+
"LocalOutlierFactor(novelty=True)-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
94+
"LocalOutlierFactor(novelty=True)-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
95+
"LogisticRegression(solver='newton-cg')-score-n_jobs_check": "uses daal4py for cpu in sklearnex",
96+
"LogisticRegression(solver='newton-cg')-fit-n_jobs_check": "uses daal4py for cpu in sklearnex",
97+
"LogisticRegression(solver='newton-cg')-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
98+
"LogisticRegression(solver='newton-cg')-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
99+
"LogisticRegression(solver='newton-cg')-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
100+
}
53101

54102

55103
def test_target_offload_ban():
@@ -78,29 +126,52 @@ def test_target_offload_ban():
78126
assert output == "", f"sklearn versioning is occuring in: \n{output}"
79127

80128

129+
def _fullpath(path):
130+
return os.path.realpath(os.path.expanduser(path))
131+
132+
81133
_TRACE_ALLOW_DICT = {
82-
i: os.path.dirname(importlib.util.find_spec(i).origin)
134+
i: _fullpath(os.path.dirname(importlib.util.find_spec(i).origin))
83135
for i in ["sklearn", "sklearnex", "onedal", "daal4py"]
84136
}
85137

86138

87139
def _whitelist_to_blacklist():
88-
"""block all standard library, builting or site packages which are not
140+
"""block all standard library, built-in or site packages which are not
89141
related to sklearn, daal4py, onedal or sklearnex"""
90142

143+
def _commonpath(inp):
144+
# ValueError generated by os.path.commonpath when it is on a separate drive
145+
try:
146+
return os.path.commonpath(inp)
147+
except ValueError:
148+
return ""
149+
91150
blacklist = []
92151
for path in sys.path:
152+
fpath = _fullpath(path)
93153
try:
94-
if any([path in i for i in _TRACE_ALLOW_DICT.values()]):
95-
blacklist += [
96-
f.path
97-
for f in os.scandir(path)
98-
if f.name not in _TRACE_ALLOW_DICT.keys()
99-
]
100-
else:
101-
blacklist += [path]
154+
# if candidate path is a parent directory to any directory in the whitelist
155+
if any(
156+
[_commonpath([i, fpath]) == fpath for i in _TRACE_ALLOW_DICT.values()]
157+
):
158+
# find all sub-paths which are not in the whitelist and block them
159+
# they should not have a common path that is either the whitelist path
160+
# or the sub-path (meaning one is a parent directory of the either)
161+
for f in os.scandir(fpath):
162+
temppath = _fullpath(f.path)
163+
if all(
164+
[
165+
_commonpath([i, temppath]) not in [i, temppath]
166+
for i in _TRACE_ALLOW_DICT.values()
167+
]
168+
):
169+
blacklist += [temppath]
170+
# add path to blacklist if not a sub path of anything in the whitelist
171+
elif all([_commonpath([i, fpath]) != i for i in _TRACE_ALLOW_DICT.values()]):
172+
blacklist += [fpath]
102173
except FileNotFoundError:
103-
blacklist += [path]
174+
blacklist += [fpath]
104175
return blacklist
105176

106177

@@ -152,7 +223,7 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
152223

153224
# initialize tracer to have a more verbose module naming
154225
# this impacts ignoremods, but it is not used.
155-
monkeypatch.setattr(trace, "_modname", lambda x: x)
226+
monkeypatch.setattr(trace, "_modname", _fullpath)
156227
tracer = trace.Trace(
157228
count=0,
158229
trace=1,
@@ -197,23 +268,28 @@ def call_validate_data(text, estimator, method):
197268
pytest.skip("onedal backend not used in this function")
198269

199270
validate_data = "validate_data" if sklearn_check_version("1.6") else "_validate_data"
200-
try:
201-
assert (
202-
validfuncs.count(validate_data) == 1
203-
), f"sklearn's {validate_data} should be called"
204-
assert (
205-
validfuncs.count("_check_feature_names") == 1
206-
), "estimator should check feature names in validate_data"
207-
except AssertionError:
208-
if "-".join([estimator, method, "call_validate_data"]) in _DESIGN_RULE_VIOLATIONS:
209-
pytest.xfail("Allowed violation of design rules")
210-
else:
211-
raise
271+
272+
assert (
273+
validfuncs.count(validate_data) == 1
274+
), f"sklearn's {validate_data} should be called"
275+
assert (
276+
validfuncs.count("_check_feature_names") == 1
277+
), "estimator should check feature names in validate_data"
212278

213279

214280
def n_jobs_check(text, estimator, method):
215281
"""verify the n_jobs is being set if '_get_backend' or 'to_table' is called"""
216-
count = max([text[0].count(name) for name in ["to_table", "_get_backend"]])
282+
# remove the _get_backend function from sklearnex from considered _get_backend
283+
count = max(
284+
text[0].count("to_table"),
285+
len(
286+
[
287+
i
288+
for i in range(len(text[0]))
289+
if text[0][i] == "_get_backend" and "sklearnex" not in text[2][i]
290+
]
291+
),
292+
)
217293
n_jobs_count = text[0].count("n_jobs_wrapper")
218294

219295
assert bool(count) == bool(
@@ -235,4 +311,11 @@ def n_jobs_check(text, estimator, method):
235311
)
236312
def test_estimator(estimator, method, design_pattern, estimator_trace):
237313
# These tests only apply to sklearnex estimators
238-
design_pattern(estimator_trace, estimator, method)
314+
try:
315+
design_pattern(estimator_trace, estimator, method)
316+
except AssertionError:
317+
key = "-".join([estimator, method, design_pattern.__name__])
318+
if key in _DESIGN_RULE_VIOLATIONS:
319+
pytest.xfail(_DESIGN_RULE_VIOLATIONS[key])
320+
else:
321+
raise

0 commit comments

Comments (0)