From 2beefbc4b331149a8bf0dba5726e017651b845dc Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 23 Apr 2017 09:34:08 +0800 Subject: [PATCH] [MRG] Improve the error message of export_graphviz if a not-fitted decision tree is provided (#8776) --- sklearn/tree/export.py | 2 ++ sklearn/tree/tests/test_export.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 43e8aa11b9611..db89ae25d9721 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -14,6 +14,7 @@ import warnings from ..externals import six +from ..utils.validation import check_is_fitted from . import _criterion from . import _tree @@ -377,6 +378,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): # Add edge to parent out_file.write('%d -> %d ;\n' % (parent, node_id)) + check_is_fitted(decision_tree, 'tree_') own_file = False return_string = False try: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 1379a7703f31f..89d9cd7370ce0 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -9,6 +9,7 @@ from sklearn.tree import export_graphviz from sklearn.externals.six import StringIO from sklearn.utils.testing import assert_in, assert_equal, assert_raises +from sklearn.exceptions import NotFittedError # toy sample X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] @@ -210,6 +211,11 @@ def test_graphviz_toy(): def test_graphviz_errors(): # Check for errors of export_graphviz clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2) + + # Check not-fitted decision tree error + out = StringIO() + assert_raises(NotFittedError, export_graphviz, clf, out) + clf.fit(X, y) # Check feature_names error