[MRG+1] Chassifier chain example fix (scikit-learn#9408)
Adam Kleczewski authored and TomDLT committed Jul 25, 2017
1 parent 11e7369 commit 511bbc7
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions examples/multioutput/plot_classifier_chain_yeast.py
@@ -5,12 +5,12 @@
 Example of using classifier chain on a multilabel dataset.
 For this example we will use the `yeast
-<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which
-contains 2417 datapoints each with 103 features and 14 possible labels. Each
-datapoint has at least one label. As a baseline we first train a logistic
-regression classifier for each of the 14 labels. To evaluate the performance
-of these classifiers we predict on a held-out test set and calculate the
-:ref:`User Guide <jaccard_similarity_score>`.
+<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which contains
+2417 datapoints each with 103 features and 14 possible labels. Each
+data point has at least one label. As a baseline we first train a logistic
+regression classifier for each of the 14 labels. To evaluate the performance of
+these classifiers we predict on a held-out test set and calculate the
+:ref:`jaccard similarity score <jaccard_similarity_score>`.
 Next we create 10 classifier chains. Each classifier chain contains a
 logistic regression model for each of the 14 labels. The models in each
@@ -79,7 +79,7 @@
 model_scores = [ovr_jaccard_score] + chain_jaccard_scores
 model_scores.append(ensemble_jaccard_score)

-model_names = ('Independent Models',
+model_names = ('Independent',
                'Chain 1',
                'Chain 2',
                'Chain 3',
@@ -90,21 +90,22 @@
                'Chain 8',
                'Chain 9',
                'Chain 10',
-               'Ensemble Average')
+               'Ensemble')

-y_pos = np.arange(len(model_names))
-y_pos[1:] += 1
-y_pos[-1] += 1
+x_pos = np.arange(len(model_names))

 # Plot the Jaccard similarity scores for the independent model, each of the
 # chains, and the ensemble (note that the vertical axis on this plot does
 # not begin at 0).

-fig = plt.figure(figsize=(7, 4))
-plt.title('Classifier Chain Ensemble')
-plt.xticks(y_pos, model_names, rotation='vertical')
-plt.ylabel('Jaccard Similarity Score')
-plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
+fig, ax = plt.subplots(figsize=(7, 4))
+ax.grid(True)
+ax.set_title('Classifier Chain Ensemble Performance Comparison')
+ax.set_xticks(x_pos)
+ax.set_xticklabels(model_names, rotation='vertical')
+ax.set_ylabel('Jaccard Similarity Score')
+ax.set_ylim([min(model_scores) * .9, max(model_scores) * 1.1])
 colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
-plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
+ax.bar(x_pos, model_scores, alpha=0.5, color=colors)
 plt.tight_layout()
 plt.show()
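
Note on the metric: the docstring in this diff scores every model with the Jaccard similarity score on a held-out test set. For readers unfamiliar with it, the following is a minimal pure-Python sketch of a sample-averaged Jaccard similarity over binary label-indicator rows. It is an illustration only, not scikit-learn's implementation; the function name `mean_jaccard_similarity` and the convention for two empty label sets are assumptions made here.

```python
def mean_jaccard_similarity(y_true, y_pred):
    """Average, over samples, of |true & pred| / |true | pred|.

    Hypothetical helper for illustration; rows are binary label indicators.
    """
    scores = []
    for true_row, pred_row in zip(y_true, y_pred):
        # Convert each indicator row into the set of label indices that are on.
        true_set = {i for i, v in enumerate(true_row) if v}
        pred_set = {i for i, v in enumerate(pred_row) if v}
        union = true_set | pred_set
        if not union:
            # Assumption: two empty label sets count as a perfect match.
            scores.append(1.0)
        else:
            scores.append(len(true_set & pred_set) / len(union))
    return sum(scores) / len(scores)


# Two samples with three possible labels each.
y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 0]]
score = mean_jaccard_similarity(y_true, y_pred)
```

In this toy case the first sample shares one of two labels and the second matches exactly, so the per-sample scores are 0.5 and 1.0. Since the docstring states that each data point in the yeast dataset has at least one label, the empty-set branch would only be reached through the convention assumed above.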