14
14
# limitations under the License.
15
15
# ==============================================================================
16
16
17
- import importlib
17
+ import importlib . util
18
18
import os
19
19
import re
20
20
import sys
44
44
"svm" + os .sep + "_common.py" ,
45
45
]
46
46
47
- _DESIGN_RULE_VIOLATIONS = [
48
- "PCA-fit_transform-call_validate_data" , # calls both "fit" and "transform"
49
- "IncrementalEmpiricalCovariance-score-call_validate_data" , # must call clone of itself
50
- "SVC(probability=True)-fit-call_validate_data" , # SVC fit can use sklearn estimator
51
- "NuSVC(probability=True)-fit-call_validate_data" , # NuSVC fit can use sklearn estimator
52
- ]
47
def _design_rule_violations():
    """Build the table of tolerated design-rule violations.

    Keys have the form ``"<estimator>-<method>-<design pattern name>"`` and
    values are the human-readable reason the violation is tolerated.
    """
    sklearnex_reason = "uses daal4py for cpu in sklearnex"
    onedal_reason = "uses daal4py for cpu in onedal"

    logreg_methods = (
        "score",
        "fit",
        "predict",
        "predict_log_proba",
        "predict_proba",
    )
    # neighbors estimators which exist both with default parameters and with
    # algorithm='brute'
    neighbors_methods = {
        "KNeighborsClassifier": (
            "kneighbors",
            "fit",
            "score",
            "predict",
            "predict_proba",
            "kneighbors_graph",
        ),
        "KNeighborsRegressor": (
            "kneighbors",
            "fit",
            "score",
            "predict",
            "kneighbors_graph",
        ),
        "NearestNeighbors": (
            "kneighbors",
            "fit",
            "radius_neighbors",
            "kneighbors_graph",
            "radius_neighbors_graph",
        ),
    }
    lof_methods = ("fit", "kneighbors", "kneighbors_graph")

    violations = {
        "PCA-fit_transform-call_validate_data": "calls both 'fit' and 'transform'",
        "IncrementalEmpiricalCovariance-score-call_validate_data": "must call clone of itself",
        "SVC(probability=True)-fit-call_validate_data": "SVC fit can use sklearn estimator",
        "NuSVC(probability=True)-fit-call_validate_data": "NuSVC fit can use sklearn estimator",
    }

    def allow_n_jobs(estimator, methods, reason):
        # register one n_jobs_check exemption per (estimator, method) pair
        for method in methods:
            violations[f"{estimator}-{method}-n_jobs_check"] = reason

    allow_n_jobs("LogisticRegression", logreg_methods, sklearnex_reason)
    for params, lof_name in (
        ("", "LocalOutlierFactor"),
        ("(algorithm='brute')", "LocalOutlierFactor(novelty=True)"),
    ):
        for name, methods in neighbors_methods.items():
            allow_n_jobs(name + params, methods, onedal_reason)
        allow_n_jobs(lof_name, lof_methods, onedal_reason)
    allow_n_jobs(
        "LogisticRegression(solver='newton-cg')", logreg_methods, sklearnex_reason
    )
    return violations


_DESIGN_RULE_VIOLATIONS = _design_rule_violations()
53
101
54
102
55
103
def test_target_offload_ban ():
@@ -78,29 +126,52 @@ def test_target_offload_ban():
78
126
assert output == "" , f"sklearn versioning is occuring in: \n { output } "
79
127
80
128
129
def _fullpath(path):
    """Return ``path`` with a leading ``~`` expanded and passed through
    ``os.path.realpath`` so that paths can be compared in one canonical form"""
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)
131
+
132
+
81
133
# package name -> canonical directory containing that package's import origin;
# these are the only locations which are allowed to show up in a trace
_TRACE_ALLOW_DICT = {
    package: _fullpath(os.path.dirname(importlib.util.find_spec(package).origin))
    for package in ["sklearn", "sklearnex", "onedal", "daal4py"]
}
85
137
86
138
87
139
def _whitelist_to_blacklist():
    """block all standard library, built-in or site packages which are not
    related to sklearn, daal4py, onedal or sklearnex

    Returns a list of canonical paths derived from ``sys.path`` which do not
    contain, and are not contained in, any directory of ``_TRACE_ALLOW_DICT``.
    """

    def _shared_root(paths):
        # os.path.commonpath raises ValueError when no common path can be
        # formed (e.g. the paths are on separate drives); treat that case as
        # "nothing in common"
        try:
            return os.path.commonpath(paths)
        except ValueError:
            return ""

    allowed = list(_TRACE_ALLOW_DICT.values())
    blocked = []
    for entry in sys.path:
        resolved = _fullpath(entry)
        try:
            # is the candidate a parent directory of (or equal to) any
            # directory in the whitelist?
            holds_allowed = any(
                _shared_root([keep, resolved]) == resolved for keep in allowed
            )
            if holds_allowed:
                # block every direct child which is unrelated to the
                # whitelist: a child is related when its common path with a
                # whitelisted directory is either that directory or the child
                # itself (meaning one is a parent directory of the other)
                for child in os.scandir(resolved):
                    child_path = _fullpath(child.path)
                    related = any(
                        _shared_root([keep, child_path]) in (keep, child_path)
                        for keep in allowed
                    )
                    if not related:
                        blocked.append(child_path)
            # block the candidate itself unless it lies inside a whitelisted
            # directory
            elif all(_shared_root([keep, resolved]) != keep for keep in allowed):
                blocked.append(resolved)
        except FileNotFoundError:
            # sys.path entries which do not exist on disk are blocked as-is
            blocked.append(resolved)
    return blocked
105
176
106
177
@@ -152,7 +223,7 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
152
223
153
224
# initialize tracer to have a more verbose module naming
154
225
# this impacts ignoremods, but it is not used.
155
- monkeypatch .setattr (trace , "_modname" , lambda x : x )
226
+ monkeypatch .setattr (trace , "_modname" , _fullpath )
156
227
tracer = trace .Trace (
157
228
count = 0 ,
158
229
trace = 1 ,
@@ -197,23 +268,28 @@ def call_validate_data(text, estimator, method):
197
268
pytest .skip ("onedal backend not used in this function" )
198
269
199
270
validate_data = "validate_data" if sklearn_check_version ("1.6" ) else "_validate_data"
200
- try :
201
- assert (
202
- validfuncs .count (validate_data ) == 1
203
- ), f"sklearn's { validate_data } should be called"
204
- assert (
205
- validfuncs .count ("_check_feature_names" ) == 1
206
- ), "estimator should check feature names in validate_data"
207
- except AssertionError :
208
- if "-" .join ([estimator , method , "call_validate_data" ]) in _DESIGN_RULE_VIOLATIONS :
209
- pytest .xfail ("Allowed violation of design rules" )
210
- else :
211
- raise
271
+
272
+ assert (
273
+ validfuncs .count (validate_data ) == 1
274
+ ), f"sklearn's { validate_data } should be called"
275
+ assert (
276
+ validfuncs .count ("_check_feature_names" ) == 1
277
+ ), "estimator should check feature names in validate_data"
212
278
213
279
214
280
def n_jobs_check (text , estimator , method ):
215
281
"""verify the n_jobs is being set if '_get_backend' or 'to_table' is called"""
216
- count = max ([text [0 ].count (name ) for name in ["to_table" , "_get_backend" ]])
282
+ # do not count calls to sklearnex's own _get_backend function as _get_backend calls
283
+ count = max (
284
+ text [0 ].count ("to_table" ),
285
+ len (
286
+ [
287
+ i
288
+ for i in range (len (text [0 ]))
289
+ if text [0 ][i ] == "_get_backend" and "sklearnex" not in text [2 ][i ]
290
+ ]
291
+ ),
292
+ )
217
293
n_jobs_count = text [0 ].count ("n_jobs_wrapper" )
218
294
219
295
assert bool (count ) == bool (
@@ -235,4 +311,11 @@ def n_jobs_check(text, estimator, method):
235
311
)
236
312
def test_estimator(estimator, method, design_pattern, estimator_trace):
    """Run one design-pattern check against the trace of an estimator method.

    An ``AssertionError`` raised by the check is turned into an expected
    failure when ``"<estimator>-<method>-<design_pattern.__name__>"`` is a key
    of ``_DESIGN_RULE_VIOLATIONS``; the stored value is used as the xfail
    reason. Any other assertion failure is re-raised unchanged.
    """
    # These tests only apply to sklearnex estimators
    try:
        design_pattern(estimator_trace, estimator, method)
    except AssertionError:
        violation_id = "-".join((estimator, method, design_pattern.__name__))
        if violation_id not in _DESIGN_RULE_VIOLATIONS:
            raise
        pytest.xfail(_DESIGN_RULE_VIOLATIONS[violation_id])
0 commit comments