Skip to content

Commit

Permalink
[MRG] Improve the error message of export_graphviz if a not-fitted de…
Browse files Browse the repository at this point in the history
…cision tree is provided (scikit-learn#8776)
  • Loading branch information
qinhanmin2014 authored and jnothman committed Apr 23, 2017
1 parent 3a3637c commit 2beefbc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sklearn/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

from ..externals import six
from ..utils.validation import check_is_fitted

from . import _criterion
from . import _tree
Expand Down Expand Up @@ -377,6 +378,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
# Add edge to parent
out_file.write('%d -> %d ;\n' % (parent, node_id))

check_is_fitted(decision_tree, 'tree_')
own_file = False
return_string = False
try:
Expand Down
6 changes: 6 additions & 0 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from sklearn.utils.testing import assert_in, assert_equal, assert_raises
from sklearn.exceptions import NotFittedError

# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
Expand Down Expand Up @@ -210,6 +211,11 @@ def test_graphviz_toy():
def test_graphviz_errors():
# Check for errors of export_graphviz
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

# Check not-fitted decision tree error
out = StringIO()
assert_raises(NotFittedError, export_graphviz, clf, out)

clf.fit(X, y)

# Check feature_names error
Expand Down

0 comments on commit 2beefbc

Please sign in to comment.