11
11
# Li Li <aiki.nogard@gmail.com>
12
12
# License: BSD 3 clause
13
13
14
+ from numbers import Integral
15
+
14
16
import numpy as np
15
17
import warnings
16
18
@@ -73,7 +75,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
73
75
feature_names = None , class_names = None , label = 'all' ,
74
76
filled = False , leaves_parallel = False , impurity = True ,
75
77
node_ids = False , proportion = False , rotate = False ,
76
- rounded = False , special_characters = False ):
78
+ rounded = False , special_characters = False , precision = 3 ):
77
79
"""Export a decision tree in DOT format.
78
80
79
81
This function generates a GraphViz representation of the decision tree,
@@ -143,6 +145,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
143
145
When set to ``False``, ignore special characters for PostScript
144
146
compatibility.
145
147
148
+ precision : int, optional (default=3)
149
+ Number of digits of precision for floating point in the values of
150
+ impurity, threshold and value attributes of each node.
151
+
146
152
Returns
147
153
-------
148
154
dot_data : string
@@ -162,6 +168,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
162
168
>>> clf = clf.fit(iris.data, iris.target)
163
169
>>> tree.export_graphviz(clf,
164
170
... out_file='tree.dot') # doctest: +SKIP
171
+
165
172
"""
166
173
167
174
def get_color (value ):
@@ -226,7 +233,8 @@ def node_to_str(tree, node_id, criterion):
226
233
characters [2 ])
227
234
node_string += '%s %s %s%s' % (feature ,
228
235
characters [3 ],
229
- round (tree .threshold [node_id ], 4 ),
236
+ round (tree .threshold [node_id ],
237
+ precision ),
230
238
characters [4 ])
231
239
232
240
# Write impurity
@@ -237,7 +245,7 @@ def node_to_str(tree, node_id, criterion):
237
245
criterion = "impurity"
238
246
if labels :
239
247
node_string += '%s = ' % criterion
240
- node_string += (str (round (tree .impurity [node_id ], 4 )) +
248
+ node_string += (str (round (tree .impurity [node_id ], precision )) +
241
249
characters [4 ])
242
250
243
251
# Write node sample count
@@ -260,16 +268,16 @@ def node_to_str(tree, node_id, criterion):
260
268
node_string += 'value = '
261
269
if tree .n_classes [0 ] == 1 :
262
270
# Regression
263
- value_text = np .around (value , 4 )
271
+ value_text = np .around (value , precision )
264
272
elif proportion :
265
273
# Classification
266
- value_text = np .around (value , 2 )
274
+ value_text = np .around (value , precision )
267
275
elif np .all (np .equal (np .mod (value , 1 ), 0 )):
268
276
# Classification without floating-point weights
269
277
value_text = value .astype (int )
270
278
else :
271
279
# Classification with floating-point weights
272
- value_text = np .around (value , 4 )
280
+ value_text = np .around (value , precision )
273
281
# Strip whitespace
274
282
value_text = str (value_text .astype ('S32' )).replace ("b'" , "'" )
275
283
value_text = value_text .replace ("' '" , ", " ).replace ("'" , "" )
@@ -402,6 +410,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
402
410
return_string = True
403
411
out_file = six .StringIO ()
404
412
413
+ if isinstance (precision , Integral ):
414
+ if precision < 0 :
415
+ raise ValueError ("'precision' should be greater or equal to 0."
416
+ " Got {} instead." .format (precision ))
417
+ else :
418
+ raise ValueError ("'precision' should be an integer. Got {}"
419
+ " instead." .format (type (precision )))
420
+
405
421
# Check length of feature_names before getting into the tree node
406
422
# Raise error if length of feature_names does not match
407
423
# n_features_ in the decision_tree
0 commit comments