Skip to content

Commit

Permalink
Merge pull request scikit-learn#3078 from ndawe/examples
Browse files Browse the repository at this point in the history
[MRG] plot_adaboost_multiclass.py: handle case where boosting terminated early
  • Loading branch information
glouppe committed Apr 17, 2014
2 parents a32eb88 + f260259 commit 64f3026
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 19 deletions.
35 changes: 23 additions & 12 deletions examples/ensemble/plot_adaboost_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

from sklearn.datasets import make_gaussian_quantiles
from sklearn.ensemble import AdaBoostClassifier
from sklearn.externals.six.moves import xrange
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

Expand Down Expand Up @@ -72,37 +71,49 @@
discrete_test_errors.append(
1. - accuracy_score(discrete_train_predict, y_test))

n_trees = xrange(1, len(bdt_discrete) + 1)
n_trees_discrete = len(bdt_discrete)
n_trees_real = len(bdt_real)

# Boosting might terminate early, but the following arrays are always
# n_estimators long. We crop them to the actual number of trees here:
discrete_estimator_errors = bdt_discrete.estimator_errors_[:n_trees_discrete]
real_estimator_errors = bdt_real.estimator_errors_[:n_trees_real]
discrete_estimator_weights = bdt_discrete.estimator_weights_[:n_trees_discrete]

plt.figure(figsize=(15, 5))

plt.subplot(131)
plt.plot(n_trees, discrete_test_errors, c='black', label='SAMME')
plt.plot(n_trees, real_test_errors, c='black',
linestyle='dashed', label='SAMME.R')
plt.plot(range(1, n_trees_discrete + 1),
discrete_test_errors, c='black', label='SAMME')
plt.plot(range(1, n_trees_real + 1),
real_test_errors, c='black',
linestyle='dashed', label='SAMME.R')
plt.legend()
plt.ylim(0.18, 0.62)
plt.ylabel('Test Error')
plt.xlabel('Number of Trees')

plt.subplot(132)
plt.plot(n_trees, bdt_discrete.estimator_errors_, "b", label='SAMME', alpha=.5)
plt.plot(n_trees, bdt_real.estimator_errors_, "r", label='SAMME.R', alpha=.5)
plt.plot(range(1, n_trees_discrete + 1), discrete_estimator_errors,
"b", label='SAMME', alpha=.5)
plt.plot(range(1, n_trees_real + 1), real_estimator_errors,
"r", label='SAMME.R', alpha=.5)
plt.legend()
plt.ylabel('Error')
plt.xlabel('Number of Trees')
plt.ylim((.2,
max(bdt_real.estimator_errors_.max(),
bdt_discrete.estimator_errors_.max()) * 1.2))
max(real_estimator_errors.max(),
discrete_estimator_errors.max()) * 1.2))
plt.xlim((-20, len(bdt_discrete) + 20))

plt.subplot(133)
plt.plot(n_trees, bdt_discrete.estimator_weights_, "b", label='SAMME')
plt.plot(range(1, n_trees_discrete + 1), discrete_estimator_weights,
"b", label='SAMME')
plt.legend()
plt.ylabel('Weight')
plt.xlabel('Number of Trees')
plt.ylim((0, bdt_discrete.estimator_weights_.max() * 1.2))
plt.xlim((-20, len(bdt_discrete) + 20))
plt.ylim((0, discrete_estimator_weights.max() * 1.2))
plt.xlim((-20, n_trees_discrete + 20))

# prevent overlapping y-axis labels
plt.subplots_adjust(wspace=0.25)
Expand Down
4 changes: 4 additions & 0 deletions examples/ensemble/plot_adaboost_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
"""
print(__doc__)

# Author: Noel Dawe <noel.dawe@gmail.com>
#
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt

Expand Down
18 changes: 11 additions & 7 deletions examples/ensemble/plot_adaboost_twoclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
"""
print(__doc__)

# Author: Noel Dawe <noel.dawe@gmail.com>
#
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -65,8 +69,8 @@
for i, n, c in zip(range(2), class_names, plot_colors):
idx = np.where(y == i)
plt.scatter(X[idx, 0], X[idx, 1],
c=c, cmap=plt.cm.Paired,
label="Class %s" % n)
c=c, cmap=plt.cm.Paired,
label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
Expand All @@ -78,11 +82,11 @@
plt.subplot(122)
for i, n, c in zip(range(2), class_names, plot_colors):
plt.hist(twoclass_output[y == i],
bins=10,
range=plot_range,
facecolor=c,
label='Class %s' % n,
alpha=.5)
bins=10,
range=plot_range,
facecolor=c,
label='Class %s' % n,
alpha=.5)
x1, x2, y1, y2 = plt.axis()
plt.axis((x1, x2, y1, y2 * 1.2))
plt.legend(loc='upper right')
Expand Down

0 comments on commit 64f3026

Please sign in to comment.