Commit 2f1c978

aikinogard authored and raghavrv committed
[MRG+1] check length of feature_names in export_graphviz (scikit-learn#8477) (scikit-learn#8512)
* scikit-learn#8477 check length of feature_names in export_graphviz
  - raise ValueError if len(feature_names) > tree.n_features
  - add unit test for len(feature_names) > tree.n_features
  - change the comment of existing unit test for len(feature_names) < tree.n_features
* fix error and warning
  - include length of feature_names and number of features in tree in the error and warning message.
  - raise error for too few feature_names
  - for too much feature_names, will use the first n_features. raise an warning for users
  - use assert_raise_message and assert_warns_message in test to check message.
* move the error and warning from node_to_str to export_graphviz so it will fail early for wrong length of feature_names
* raise error if length of feature_names does not match number of features in the decision tree
* fix pep8
* remove unused assert_warns_message import in test_export.py
* add bug fix in doc/whats_new.rst
* fix the english in doc/whats_new.rst
* fix the format and english in sklearn/tree/export.py
* fix contributor format in doc/whats_new.rst
* fix english, use bracket and avoid \ in error message
* fix pep8
* fix pep8
1 parent 28a0d56 commit 2f1c978

3 files changed, +32 -5 lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
@@ -276,6 +276,11 @@ Bug fixes
 - Fixed a bug in :class:`svm.OneClassSVM` where it returned floats instead of
   integer classes. :issue:`8676` by :user:`Vathsala Achar <VathsalaAchar>`.

+- Fixed a bug where :func:`sklearn.tree.export_graphviz` raised an error
+  when the length of features_names does not match n_features in the decision
+  tree.
+  :issue:`8512` by :user:`Li Li <aikinogard>`.
+
 API changes summary
 -------------------

sklearn/tree/export.py

Lines changed: 15 additions & 2 deletions
@@ -8,6 +8,7 @@
 #          Noel Dawe <noel@dawe.me>
 #          Satrajit Gosh <satrajit.ghosh@gmail.com>
 #          Trevor Stephens <trev.stephens@gmail.com>
+#          Li Li <aiki.nogard@gmail.com>
 # License: BSD 3 clause

 import numpy as np
@@ -172,7 +173,8 @@ def get_color(value):
             if len(sorted_values) == 1:
                 alpha = 0
             else:
-                alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
+                alpha = int(np.round(255 * (sorted_values[0] -
+                                            sorted_values[1]) /
                                      (1 - sorted_values[1]), 0))
         else:
             # Regression tree or multi-output
@@ -330,7 +332,8 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
                         # Find max and min impurities for multi-output
                         colors['bounds'] = (np.min(-tree.impurity),
                                             np.max(-tree.impurity))
-                    elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1:
+                    elif (tree.n_classes[0] == 1 and
+                          len(np.unique(tree.value)) != 1):
                         # Find max and min values in leaf nodes for regression
                         colors['bounds'] = (np.min(tree.value),
                                             np.max(tree.value))
@@ -399,6 +402,16 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
             return_string = True
             out_file = six.StringIO()

+        # Check length of feature_names before getting into the tree node
+        # Raise error if length of feature_names does not match
+        # n_features_ in the decision_tree
+        if feature_names is not None:
+            if len(feature_names) != decision_tree.n_features_:
+                raise ValueError("Length of feature_names, %d "
+                                 "does not match number of features, %d"
+                                 % (len(feature_names),
+                                    decision_tree.n_features_))
+
         # The depth of each node for plotting with 'leaf' option
         ranks = {'leaves': []}
         # The colors to render each node with
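
The length check added in the last hunk above can be tried in isolation. The following is a minimal sketch, not the scikit-learn code itself: `check_feature_names` and `describe_check` are hypothetical helper names, and `n_features` is passed in directly instead of being read from `decision_tree.n_features_`. Only the condition and the message format are taken from the diff.

```python
def check_feature_names(feature_names, n_features):
    """Raise ValueError when feature_names has the wrong length.

    Hypothetical helper: in the commit this check is inlined in
    export_graphviz and compares against decision_tree.n_features_.
    """
    if feature_names is not None:
        if len(feature_names) != n_features:
            raise ValueError("Length of feature_names, %d "
                             "does not match number of features, %d"
                             % (len(feature_names), n_features))


def describe_check(feature_names, n_features):
    """Return 'ok' or the error message (demonstration only)."""
    try:
        check_feature_names(feature_names, n_features)
    except ValueError as exc:
        return str(exc)
    return "ok"


if __name__ == "__main__":
    # None skips the check; a matching length passes.
    print(describe_check(None, 2))
    print(describe_check(["a", "b"], 2))
    # Too few and too many names are both rejected.
    print(describe_check(["a"], 2))
    print(describe_check(["a", "b", "c"], 2))
```

Because the check runs before any node is visited, a wrong-length list fails immediately, whether it is too short or too long.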

sklearn/tree/tests/test_export.py

Lines changed: 12 additions & 3 deletions
@@ -9,6 +9,7 @@
 from sklearn.tree import export_graphviz
 from sklearn.externals.six import StringIO
 from sklearn.utils.testing import assert_in, assert_equal, assert_raises
+from sklearn.utils.testing import assert_raise_message
 from sklearn.exceptions import NotFittedError

 # toy sample
@@ -218,9 +219,17 @@ def test_graphviz_errors():

     clf.fit(X, y)

-    # Check feature_names error
-    out = StringIO()
-    assert_raises(IndexError, export_graphviz, clf, out, feature_names=[])
+    # Check if it errors when length of feature_names
+    # mismatches with number of features
+    message = ("Length of feature_names, "
+               "1 does not match number of features, 2")
+    assert_raise_message(ValueError, message, export_graphviz, clf, None,
+                         feature_names=["a"])
+
+    message = ("Length of feature_names, "
+               "3 does not match number of features, 2")
+    assert_raise_message(ValueError, message, export_graphviz, clf, None,
+                         feature_names=["a", "b", "c"])

     # Check class_names error
     out = StringIO()
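
The updated test replaces a bare `IndexError` assertion with checks on both the exception type and its message. The sketch below shows the same testing pattern using only the standard library. `raises_with_message` and `fake_export` are hypothetical stand-ins, not scikit-learn's `assert_raise_message` or `export_graphviz`. The substring comparison is an assumption about how such a helper typically behaves.

```python
def raises_with_message(exc_type, message, func, *args, **kwargs):
    """Return True if func raises exc_type and message occurs in its text.

    Hypothetical stand-in for an assert_raise_message-style helper.
    """
    try:
        func(*args, **kwargs)
    except exc_type as exc:
        return message in str(exc)
    return False


def fake_export(n_features, out_file=None, feature_names=None):
    """Hypothetical stand-in that only mirrors the length check."""
    if feature_names is not None:
        if len(feature_names) != n_features:
            raise ValueError("Length of feature_names, %d "
                             "does not match number of features, %d"
                             % (len(feature_names), n_features))
    return "digraph Tree {}"


if __name__ == "__main__":
    too_few = ("Length of feature_names, "
               "1 does not match number of features, 2")
    too_many = ("Length of feature_names, "
                "3 does not match number of features, 2")
    print(raises_with_message(ValueError, too_few, fake_export, 2, None,
                              feature_names=["a"]))
    print(raises_with_message(ValueError, too_many, fake_export, 2, None,
                              feature_names=["a", "b", "c"]))
```

Checking the message as well as the type ties the test to the reported lengths, so a `ValueError` raised for some unrelated reason would not satisfy it.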
