Skip to content

Commit 00da9cc

Browse files
glemaitre authored and raghavrv committed
[MRG+1] EHN add decimals parameter for export_graphviz (scikit-learn#8698)
* EHN add decimals parameter for export_graphviz
* FIX address comments
* TST add test for classification
* TST/FIX address comments
* FIX comments raghav
1 parent b578371 commit 00da9cc

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

sklearn/tree/export.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# Li Li <aiki.nogard@gmail.com>
1212
# License: BSD 3 clause
1313

14+
from numbers import Integral
15+
1416
import numpy as np
1517
import warnings
1618

@@ -73,7 +75,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
7375
feature_names=None, class_names=None, label='all',
7476
filled=False, leaves_parallel=False, impurity=True,
7577
node_ids=False, proportion=False, rotate=False,
76-
rounded=False, special_characters=False):
78+
rounded=False, special_characters=False, precision=3):
7779
"""Export a decision tree in DOT format.
7880
7981
This function generates a GraphViz representation of the decision tree,
@@ -143,6 +145,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
143145
When set to ``False``, ignore special characters for PostScript
144146
compatibility.
145147
148+
precision : int, optional (default=3)
149+
Number of digits of precision for floating point in the values of
150+
impurity, threshold and value attributes of each node.
151+
146152
Returns
147153
-------
148154
dot_data : string
@@ -162,6 +168,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
162168
>>> clf = clf.fit(iris.data, iris.target)
163169
>>> tree.export_graphviz(clf,
164170
... out_file='tree.dot') # doctest: +SKIP
171+
165172
"""
166173

167174
def get_color(value):
@@ -226,7 +233,8 @@ def node_to_str(tree, node_id, criterion):
226233
characters[2])
227234
node_string += '%s %s %s%s' % (feature,
228235
characters[3],
229-
round(tree.threshold[node_id], 4),
236+
round(tree.threshold[node_id],
237+
precision),
230238
characters[4])
231239

232240
# Write impurity
@@ -237,7 +245,7 @@ def node_to_str(tree, node_id, criterion):
237245
criterion = "impurity"
238246
if labels:
239247
node_string += '%s = ' % criterion
240-
node_string += (str(round(tree.impurity[node_id], 4)) +
248+
node_string += (str(round(tree.impurity[node_id], precision)) +
241249
characters[4])
242250

243251
# Write node sample count
@@ -260,16 +268,16 @@ def node_to_str(tree, node_id, criterion):
260268
node_string += 'value = '
261269
if tree.n_classes[0] == 1:
262270
# Regression
263-
value_text = np.around(value, 4)
271+
value_text = np.around(value, precision)
264272
elif proportion:
265273
# Classification
266-
value_text = np.around(value, 2)
274+
value_text = np.around(value, precision)
267275
elif np.all(np.equal(np.mod(value, 1), 0)):
268276
# Classification without floating-point weights
269277
value_text = value.astype(int)
270278
else:
271279
# Classification with floating-point weights
272-
value_text = np.around(value, 4)
280+
value_text = np.around(value, precision)
273281
# Strip whitespace
274282
value_text = str(value_text.astype('S32')).replace("b'", "'")
275283
value_text = value_text.replace("' '", ", ").replace("'", "")
@@ -402,6 +410,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
402410
return_string = True
403411
out_file = six.StringIO()
404412

413+
if isinstance(precision, Integral):
414+
if precision < 0:
415+
raise ValueError("'precision' should be greater or equal to 0."
416+
" Got {} instead.".format(precision))
417+
else:
418+
raise ValueError("'precision' should be an integer. Got {}"
419+
" instead.".format(type(precision)))
420+
405421
# Check length of feature_names before getting into the tree node
406422
# Raise error if length of feature_names does not match
407423
# n_features_ in the decision_tree

sklearn/tree/tests/test_export.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
Testing for export functions of decision trees (sklearn.tree.export).
33
"""
44

5-
from re import finditer
5+
from re import finditer, search
66

7+
from numpy.random import RandomState
8+
9+
from sklearn.base import ClassifierMixin
710
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
811
from sklearn.ensemble import GradientBoostingClassifier
912
from sklearn.tree import export_graphviz
1013
from sklearn.externals.six import StringIO
11-
from sklearn.utils.testing import assert_in, assert_equal, assert_raises
12-
from sklearn.utils.testing import assert_raise_message
14+
from sklearn.utils.testing import (assert_in, assert_equal, assert_raises,
15+
assert_less_equal, assert_raises_regex,
16+
assert_raise_message)
1317
from sklearn.exceptions import NotFittedError
1418

1519
# toy sample
@@ -235,6 +239,13 @@ def test_graphviz_errors():
235239
out = StringIO()
236240
assert_raises(IndexError, export_graphviz, clf, out, class_names=[])
237241

242+
# Check precision error
243+
out = StringIO()
244+
assert_raises_regex(ValueError, "should be greater or equal",
245+
export_graphviz, clf, out, precision=-1)
246+
assert_raises_regex(ValueError, "should be an integer",
247+
export_graphviz, clf, out, precision="1")
248+
238249

239250
def test_friedman_mse_in_graphviz():
240251
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
@@ -249,3 +260,48 @@ def test_friedman_mse_in_graphviz():
249260

250261
for finding in finditer("\[.*?samples.*?\]", dot_data.getvalue()):
251262
assert_in("friedman_mse", finding.group())
263+
264+
265+
def test_precision():
    """Check the number of decimals written by ``export_graphviz``.

    A depth-1 regressor (``friedman_mse`` criterion) and a depth-1
    classifier are fitted on random data and exported with ``precision``
    set to 4 and then 3. The DOT output is scanned for the ``value``,
    impurity and threshold entries, and the number of digits after the
    decimal point is compared with ``precision``.
    """
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    for X, y, clf in zip(
            (rng_reg.random_sample((5, 2)),
             rng_clf.random_sample((1000, 4))),
            (rng_reg.random_sample((5, )),
             rng_clf.randint(2, size=(1000, ))),
            (DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                                   max_depth=1),
             DecisionTreeClassifier(max_depth=1, random_state=0))):

        clf.fit(X, y)
        for precision in (4, 3):
            dot_data = export_graphviz(clf, out_file=None, precision=precision,
                                       proportion=True)

            # With the current random state, the impurity and the threshold
            # are printed with exactly `precision` decimals, so they are
            # checked with a strict equality. The value may be printed with
            # fewer decimals, so only a less-or-equal comparison is done.
            # In every check below, `search(r"\.\d+", ...)` captures the dot
            # plus the decimals, hence the `precision + 1` length.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert_less_equal(
                    len(search(r"\.\d+", finding.group()).group()),
                    precision + 1)

            # check impurity
            if isinstance(clf, ClassifierMixin):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"

            for finding in finditer(pattern, dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)

            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)

0 commit comments

Comments
 (0)