Skip to content

Commit 2beefbc

Browse files
qinhanmin2014jnothman
authored andcommitted
[MRG] Improve the error message of export_graphviz if a not-fitted decision tree is provided (scikit-learn#8776)
1 parent 3a3637c commit 2beefbc

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sklearn/tree/export.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import warnings
1515

1616
from ..externals import six
17+
from ..utils.validation import check_is_fitted
1718

1819
from . import _criterion
1920
from . import _tree
@@ -377,6 +378,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
377378
# Add edge to parent
378379
out_file.write('%d -> %d ;\n' % (parent, node_id))
379380

381+
check_is_fitted(decision_tree, 'tree_')
380382
own_file = False
381383
return_string = False
382384
try:

sklearn/tree/tests/test_export.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.tree import export_graphviz
1010
from sklearn.externals.six import StringIO
1111
from sklearn.utils.testing import assert_in, assert_equal, assert_raises
12+
from sklearn.exceptions import NotFittedError
1213

1314
# toy sample
1415
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
@@ -210,6 +211,11 @@ def test_graphviz_toy():
210211
def test_graphviz_errors():
211212
# Check for errors of export_graphviz
212213
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
214+
215+
# Check not-fitted decision tree error
216+
out = StringIO()
217+
assert_raises(NotFittedError, export_graphviz, clf, out)
218+
213219
clf.fit(X, y)
214220

215221
# Check feature_names error

0 commit comments

Comments
 (0)