Commit aa69d85

FIX scikit-learn#401: update tutorial doctests to reflect recent changes and add them to …

1 parent c06ebe7

4 files changed, +91 -87 lines changed

Makefile

+1 -1

@@ -32,7 +32,7 @@ test-code: in
 	$(NOSETESTS) -s sklearn
 test-doc:
 	$(NOSETESTS) -s --with-doctest --doctest-tests --doctest-extension=rst \
-	--doctest-fixtures=_fixture doc/modules/
+	--doctest-fixtures=_fixture doc/ doc/modules/
 
 test-coverage:
 	$(NOSETESTS) -s --with-coverage --cover-html --cover-html-dir=coverage \
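
The change above adds doc/ to the nose doctest run, so the tutorial's interactive examples are now executed as tests. For illustration only, a minimal sketch of running the doctests of a single .rst file with the standard-library doctest module; this is not what the Makefile target does (it uses nose), and the path is just an example:

    # Sketch: run the ">>> ..." examples embedded in one .rst file using
    # the standard-library doctest module. The Makefile target instead
    # uses nose with --doctest-extension=rst; the path below is only an
    # example.
    import doctest

    results = doctest.testfile(
        "doc/tutorial.rst",            # file containing doctest examples
        module_relative=False,         # interpret the path relative to the cwd
        optionflags=doctest.ELLIPSIS,  # allow "..." in expected output
    )
    print("attempted=%d failed=%d" % (results.attempted, results.failed))
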

doc/modules/cross_validation.rst

+2

@@ -1,3 +1,5 @@
+.. _cross_validation:
+
 ================
 Cross-Validation
 ================
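
The new ``.. _cross_validation:`` label is what lets the tutorial's new "Choosing the parameters of the model" topic link here via :ref:`cross validation <cross_validation>`. As a rough sketch of what that chapter covers, scoring an estimator on several train/test splits; written against the modern sklearn.model_selection layout, which did not exist at the time of this commit:

    # Sketch: k-fold cross-validation of an SVM on the iris data.
    # cross_val_score is imported from the modern sklearn.model_selection
    # module; the equivalent helpers lived elsewhere when this commit was made.
    from sklearn import datasets, svm
    from sklearn.model_selection import cross_val_score

    iris = datasets.load_iris()
    clf = svm.SVC()

    # Five train/test splits, one accuracy score per split.
    scores = cross_val_score(clf, iris.data, iris.target, cv=5)
    print(scores.mean())
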

doc/tutorial.rst

+87 -85

@@ -12,11 +12,11 @@ Getting started: an introduction to machine learning with scikit-learn
 Machine learning: the problem setting
 ---------------------------------------
 
-In general, a learning problem considers a set of n *samples* of data and
-try to predict properties of unknown data. If each sample is more than a
-single number, and for instance a multi-dimensional entry (aka
-*multivariate* data), is it said to have several attributes, or
-*features*.
+In general, a learning problem considers a set of n **samples** of
+data and tries to predict properties of unknown data. If each sample is
+more than a single number, and for instance a multi-dimensional entry
+(aka **multivariate** data), it is said to have several attributes,
+or **features**.
 
 We can separate learning problems in a few large categories:
 
@@ -46,12 +46,12 @@ We can separate learning problems in a few large categories:
 
 .. topic:: Training set and testing set
 
-    Machine learning is about learning some properties of a data set and
-    applying them to new data. This is why a common practice in machine
-    learning to evaluate an algorithm is to split the data at hand in two
-    sets, one that we call a *training set* on which we learn data
-    properties, and one that we call a *testing set*, on which we test
-    these properties.
+    Machine learning is about learning some properties of a data set
+    and applying them to new data. This is why a common practice in
+    machine learning, when evaluating an algorithm, is to split the data
+    at hand in two sets: one that we call a **training set**, on which
+    we learn data properties, and one that we call a **testing set**,
+    on which we test these properties.
 
 
 Loading an example dataset
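
The rewrapped topic above describes the training/testing split; in the tutorial itself the split is done by slicing the arrays by hand. A minimal sketch of the same idea using a helper, assuming the modern sklearn.model_selection layout (not part of this diff):

    # Sketch: split the digits data into a training set and a testing set.
    # train_test_split is assumed from the modern sklearn.model_selection
    # module; the tutorial's own doctests simply slice the arrays.
    from sklearn import datasets
    from sklearn.model_selection import train_test_split

    digits = datasets.load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, test_size=0.25, random_state=0)
    print(X_train.shape, X_test.shape)
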
@@ -63,65 +63,57 @@ Loading an example dataset
 datasets for classification and the `boston house prices dataset
 <http://archive.ics.uci.edu/ml/datasets/Housing>`_ for regression.::
 
-    >>> from sklearn import datasets
-    >>> iris = datasets.load_iris()
-    >>> digits = datasets.load_digits()
+    >>> from sklearn import datasets
+    >>> iris = datasets.load_iris()
+    >>> digits = datasets.load_digits()
 
 A dataset is a dictionary-like object that holds all the data and some
-metadata about the data. This data is stored in the `.data` member, which
-is a `n_samples, n_features` array. In the case of supervised problem,
-explanatory variables are stored in the `.target` member. More details on
-the different datasets can be found in the
-:ref:`dedicated section <datasets>`.
+metadata about the data. This data is stored in the ``.data`` member,
+which is a ``n_samples, n_features`` array. In the case of a supervised
+problem, explanatory variables are stored in the ``.target`` member. More
+details on the different datasets can be found in the :ref:`dedicated
+section <datasets>`.
 
-For instance, in the case of the digits dataset, `digits.data` gives
+For instance, in the case of the digits dataset, ``digits.data`` gives
 access to the features that can be used to classify the digits samples::
 
-    >>> print digits.data
-    [[ 0. 0. 5. ..., 0. 0. 0.]
-     [ 0. 0. 0. ..., 10. 0. 0.]
-     [ 0. 0. 0. ..., 16. 9. 0.]
-     ...,
-     [ 0. 0. 1. ..., 6. 0. 0.]
-     [ 0. 0. 2. ..., 12. 0. 0.]
-     [ 0. 0. 10. ..., 12. 1. 0.]]
+    >>> print digits.data
+    [[ 0. 0. 5. ..., 0. 0. 0.]
+     [ 0. 0. 0. ..., 10. 0. 0.]
+     [ 0. 0. 0. ..., 16. 9. 0.]
+     ...,
+     [ 0. 0. 1. ..., 6. 0. 0.]
+     [ 0. 0. 2. ..., 12. 0. 0.]
+     [ 0. 0. 10. ..., 12. 1. 0.]]
 
 and `digits.target` gives the ground truth for the digit dataset, that
 is the number corresponding to each digit image that we are trying to
-learn:
+learn::
 
->>> digits.target
-array([0, 1, 2, ..., 8, 9, 8])
+    >>> digits.target
+    array([0, 1, 2, ..., 8, 9, 8])
 
 .. topic:: Shape of the data arrays
 
     The data is always a 2D array, `n_samples, n_features`, although
     the original data may have had a different shape. In the case of the
    digits, each original sample is an image of shape `8, 8` and can be
-    accessed using:
+    accessed using::
 
-    >>> digits.images[0]
-    array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
-           [ 0., 0., 13., 15., 10., 15., 5., 0.],
-           [ 0., 3., 15., 2., 0., 11., 8., 0.],
-           [ 0., 4., 12., 0., 0., 8., 8., 0.],
-           [ 0., 5., 8., 0., 0., 9., 8., 0.],
-           [ 0., 4., 11., 0., 1., 12., 7., 0.],
-           [ 0., 2., 14., 5., 10., 12., 0., 0.],
-           [ 0., 0., 6., 13., 10., 0., 0., 0.]])
+        >>> digits.images[0]
+        array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
+               [ 0., 0., 13., 15., 10., 15., 5., 0.],
+               [ 0., 3., 15., 2., 0., 11., 8., 0.],
+               [ 0., 4., 12., 0., 0., 8., 8., 0.],
+               [ 0., 5., 8., 0., 0., 9., 8., 0.],
+               [ 0., 4., 11., 0., 1., 12., 7., 0.],
+               [ 0., 2., 14., 5., 10., 12., 0., 0.],
+               [ 0., 0., 6., 13., 10., 0., 0., 0.]])
 
-    The :ref:`simple example on this dataset <example_plot_digits_classification.py>`
-    illustrates how starting from the original problem one can shape the
-    data for consumption in the `scikit-learn`.
-
-
-``sklearn`` also offers the possibility to reuse external datasets coming
-from the http://mlcomp.org online service that provides a repository of public
-datasets for various tasks (binary & multi label classification, regression,
-document classification, ...) along with a runtime environment to compare
-program performance on those datasets. Please refer to the following example for
-for instructions on the ``mlcomp`` dataset loader:
-:ref:`example mlcomp sparse document classification <example_mlcomp_sparse_document_classification.py>`.
+    The :ref:`simple example on this dataset
+    <example_plot_digits_classification.py>` illustrates how starting
+    from the original problem one can shape the data for consumption in
+    the `scikit-learn`.
 
 
 Learning and Predicting
132124
possible classes on which we *fit* an `estimator` to be able to *predict*
133125
the labels corresponding to new data.
134126

135-
In `scikit-learn`, an *estimator* is just a plain Python class that
127+
In `scikit-learn`, an **estimator** is just a plain Python class that
136128
implements the methods `fit(X, Y)` and `predict(T)`.
137129

138130
An example of estimator is the class ``sklearn.svm.SVC`` that
139131
implements `Support Vector Classification
140132
<http://en.wikipedia.org/wiki/Support_vector_machine>`_. The
141133
constructor of an estimator takes as arguments the parameters of the
142134
model, but for the time being, we will consider the estimator as a black
143-
box and not worry about these:
135+
box::
136+
137+
>>> from sklearn import svm
138+
>>> clf = svm.SVC(gamma=0.001)
144139

145-
>>> from sklearn import svm
146-
>>> clf = svm.SVC()
140+
.. topic:: Choosing the parameters of the model
141+
142+
In this example we set the value of ``gamma`` manually. It is possible
143+
to automatically find good values for the parameters by using tools
144+
such as :ref:`grid search <grid_search>` and :ref:`cross validation
145+
<cross_validation>`.
147146

148147
We call our estimator instance `clf` as it is a classifier. It now must
149148
be fitted to the model, that is, it must `learn` from the model. This is
150149
done by passing our training set to the ``fit`` method. As a training
151150
set, let us use all the images of our dataset apart from the last
152-
one:
151+
one::
153152

154-
>>> clf.fit(digits.data[:-1], digits.target[:-1])
155-
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, tol=0.001,
156-
cache_size=100.0, shrinking=True, gamma=0.000556792873051)
153+
>>> clf.fit(digits.data[:-1], digits.target[:-1])
154+
SVC(C=1.0, coef0=0.0, degree=3, gamma=0.001, kernel='rbf', probability=False,
155+
shrinking=True, tol=0.001)
157156

158157
Now you can predict new values, in particular, we can ask to the
159158
classifier what is the digit of our last image in the `digits` dataset,
160-
which we have not used to train the classifier:
159+
which we have not used to train the classifier::
161160

162-
>>> clf.predict(digits.data[-1])
163-
array([ 8.])
161+
>>> clf.predict(digits.data[-1])
162+
array([ 8.])
164163

165164
The corresponding image is the following:
166165

@@ -175,32 +174,35 @@ A complete example of this classification problem is available as an
 example that you can run and study:
 :ref:`example_plot_digits_classification.py`.
 
+
 Model persistence
 -----------------
 
 It is possible to save a model in the scikit by using Python's built-in
-persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_.
-
->>> from sklearn import svm
->>> from sklearn import datasets
->>> clf = svm.SVC()
->>> iris = datasets.load_iris()
->>> X, y = iris.data, iris.target
->>> clf.fit(X, y)
-SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, tol=0.001,
-  cache_size=100.0, shrinking=True, gamma=0.00666666666667)
->>> import pickle
->>> s = pickle.dumps(clf)
->>> clf2 = pickle.loads(s)
->>> clf2.predict(X[0])
-array([ 0.])
->>> y[0]
-0
+persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_::
+
+    >>> from sklearn import svm
+    >>> from sklearn import datasets
+    >>> clf = svm.SVC()
+    >>> iris = datasets.load_iris()
+    >>> X, y = iris.data, iris.target
+    >>> clf.fit(X, y)
+    SVC(C=1.0, coef0=0.0, degree=3, gamma=0.25, kernel='rbf', probability=False,
+      shrinking=True, tol=0.001)
+
+    >>> import pickle
+    >>> s = pickle.dumps(clf)
+    >>> clf2 = pickle.loads(s)
+    >>> clf2.predict(X[0])
+    array([ 0.])
+    >>> y[0]
+    0
 
 In the specific case of the scikit, it may be more interesting to use
-joblib's replacement of pickle, which is more efficient on big data, but
-can only pickle to the disk and not to a string:
+joblib's replacement of pickle (``joblib.dump`` & ``joblib.load``),
+which is more efficient on big data, but can only pickle to the disk
+and not to a string::
 
->>> from sklearn.externals import joblib
->>> joblib.dump(clf, 'filename.pkl') # doctest: +SKIP
+    >>> from sklearn.externals import joblib
+    >>> joblib.dump(clf, 'filename.pkl') # doctest: +SKIP
examples/plot_digits_classification.py

+1 -1

@@ -40,7 +40,7 @@
 data = digits.images.reshape((n_samples, -1))
 
 # Create a classifier: a support vector classifier
-classifier = svm.SVC()
+classifier = svm.SVC(gamma=0.001)
 
 # We learn the digits on the first half of the digits
 classifier.fit(data[:n_samples/2], digits.target[:n_samples/2])
