Finish up mlp_refactoring and squash previous commits
glennq committed Oct 23, 2015
1 parent bde2270 commit 917bacb
Showing 12 changed files with 1,057 additions and 527 deletions.
32 changes: 19 additions & 13 deletions benchmarks/bench_mnist.py
@@ -15,13 +15,15 @@
===========================
Classifier train-time test-time error-rate
------------------------------------------------------------
MultilayerPerceptron 475.76s 1.31s 0.0201
Nystroem-SVM 218.38s 17.86s 0.0229
ExtraTrees 45.54s 0.52s 0.0288
RandomForest 44.79s 0.32s 0.0304
SampledRBF-SVM 265.64s 19.78s 0.0488
CART 21.13s 0.01s 0.1214
dummy 0.01s 0.01s 0.8973
MLP_adam 53.46s 0.11s 0.0224
Nystroem-SVM 112.97s 0.92s 0.0228
MultilayerPerceptron 24.33s 0.14s 0.0287
ExtraTrees 42.99s 0.57s 0.0294
RandomForest 42.70s 0.49s 0.0318
SampledRBF-SVM 135.81s 0.56s 0.0486
LinearRegression-SAG 16.67s 0.06s 0.0824
CART 20.69s 0.02s 0.1219
dummy 0.00s 0.01s 0.8973
"""
from __future__ import division, print_function

@@ -85,14 +87,18 @@ def load_data(dtype=np.float32, order='F'):
'CART': DecisionTreeClassifier(),
'ExtraTrees': ExtraTreesClassifier(n_estimators=100),
'RandomForest': RandomForestClassifier(n_estimators=100),
'Nystroem-SVM':
make_pipeline(Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'SampledRBF-SVM':
make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4)
'Nystroem-SVM': make_pipeline(
Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'SampledRBF-SVM': make_pipeline(
RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4),
'MultilayerPerceptron': MLPClassifier(
hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
algorithm='sgd', learning_rate_init=0.5, momentum=0.9, verbose=1,
algorithm='sgd', learning_rate_init=0.2, momentum=0.9, verbose=1,
tol=1e-4, random_state=1),
'MLP-adam': MLPClassifier(
hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
algorithm='adam', learning_rate_init=0.001, verbose=1,
tol=1e-4, random_state=1)
}
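
A minimal sketch of how one of these estimators might be timed against the train-time, test-time and error-rate columns reported above; it assumes X_train, X_test, y_train and y_test have already been produced by load_data, and it is not the benchmark's own timing loop (which is truncated in this hunk):

    import time
    import numpy as np
    from sklearn.neural_network import MLPClassifier

    # illustrative timing of a single estimator, mirroring the table's columns
    clf = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
                        algorithm='adam', learning_rate_init=0.001, tol=1e-4,
                        random_state=1)
    t0 = time.time()
    clf.fit(X_train, y_train)               # train-time
    train_time = time.time() - t0
    t0 = time.time()
    y_pred = clf.predict(X_test)            # test-time
    test_time = time.time() - t0
    error_rate = np.mean(y_pred != y_test)  # error-rate
    print("%.2fs %.2fs %.4f" % (train_time, test_time, error_rate))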

209 changes: 115 additions & 94 deletions doc/modules/neural_networks_supervised.rst

Large diffs are not rendered by default.

21 changes: 10 additions & 11 deletions examples/neural_networks/plot_mlp_alpha.py
@@ -3,18 +3,17 @@
Varying regularization in Multi-layer Perceptron
================================================
A comparison of different regularization term 'alpha' values on synthetic
datasets. The plot shows that different alphas yield different decision
functions.
Alpha is a regularization term, or also known as penalty term, that combats
overfitting by constraining the weights' size. Increasing alpha may fix high
variance (a sign of overfitting) by encouraging smaller weights, resulting
in a decision function plot that may appear with lesser curvatures.
A comparison of different values for regularization parameter 'alpha' on
synthetic datasets. The plot shows that different alphas yield different
decision functions.
Alpha is a parameter for the regularization term, also known as the penalty term,
that combats overfitting by constraining the size of the weights. Increasing alpha
may fix high variance (a sign of overfitting) by encouraging smaller weights,
resulting in a decision boundary plot that appears with less curvature.
Similarly, decreasing alpha may fix high bias (a sign of underfitting) by
encouraging larger weights, potentially resulting in more curvatures in the
decision function plot.
encouraging larger weights, potentially resulting in a more complicated
decision boundary.
"""
print(__doc__)
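
A minimal sketch of the comparison the docstring describes, building one MLPClassifier per alpha value so that larger alphas shrink the weights; the alpha grid and the fixed hyperparameters below are illustrative rather than the example's exact values:

    import numpy as np
    from sklearn.neural_network import MLPClassifier

    # one classifier per regularization strength
    alphas = np.logspace(-5, 3, 5)
    classifiers = [MLPClassifier(alpha=alpha, hidden_layer_sizes=(100,),
                                 random_state=1)
                   for alpha in alphas]
    names = ['alpha %g' % alpha for alpha in alphas]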

46 changes: 0 additions & 46 deletions examples/neural_networks/plot_mlp_nonlinear.py

This file was deleted.

42 changes: 26 additions & 16 deletions examples/neural_networks/plot_mlp_training_curves.py
@@ -1,38 +1,48 @@
"""
==================================================
Compare SGD learning strategies for MLPClassifier
==================================================
========================================================
Compare Stochastic learning strategies for MLPClassifier
========================================================
This example visualizes some training loss curves for different SGD mini-batch
learning strategies. Because of time-constraints, we use several small
datasets, for which L-BFGS might be more suitable. The general trend shown in
these examples seems to carry over to larger datasets, however.
This example visualizes some training loss curves for different stochastic
learning strategies, including SGD and Adam. Because of time constraints, we
use several small datasets, for which L-BFGS might be more suitable. The
general trend shown in these examples seems to carry over to larger datasets,
however.
"""

print(__doc__)
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn import datasets

# different learning rate schedules and momentum parameters
params = [{'learning_rate': 'constant', 'momentum': 0},
{'learning_rate': 'constant', 'momentum': .9, 'nesterovs_momentum': False},
{'learning_rate': 'constant', 'momentum': .9, 'nesterovs_momentum': True},
{'learning_rate': 'invscaling', 'momentum': 0},
{'learning_rate': 'invscaling', 'momentum': .9, 'nesterovs_momentum': True},
{'learning_rate': 'invscaling', 'momentum': .9, 'nesterovs_momentum': False}]
params = [{'algorithm': 'sgd', 'learning_rate': 'constant', 'momentum': 0,
'learning_rate_init': 0.2},
{'algorithm': 'sgd', 'learning_rate': 'constant', 'momentum': .9,
'nesterovs_momentum': False, 'learning_rate_init': 0.2},
{'algorithm': 'sgd', 'learning_rate': 'constant', 'momentum': .9,
'nesterovs_momentum': True, 'learning_rate_init': 0.2},
{'algorithm': 'sgd', 'learning_rate': 'invscaling', 'momentum': 0,
'learning_rate_init': 0.2},
{'algorithm': 'sgd', 'learning_rate': 'invscaling', 'momentum': .9,
'nesterovs_momentum': True, 'learning_rate_init': 0.2},
{'algorithm': 'sgd', 'learning_rate': 'invscaling', 'momentum': .9,
'nesterovs_momentum': False, 'learning_rate_init': 0.2},
{'algorithm': 'adam'}]

labels = ["constant learning-rate", "constant with momentum",
"constant with Nesterov's momentum",
"inv-scaling learning-rate", "inv-scaling with momentum",
"inv-scaling with Nesterov's momentum"]
"inv-scaling with Nesterov's momentum", "adam"]

plot_args = [{'c': 'red', 'linestyle': '-'},
{'c': 'green', 'linestyle': '-'},
{'c': 'blue', 'linestyle': '-'},
{'c': 'red', 'linestyle': '--'},
{'c': 'green', 'linestyle': '--'},
{'c': 'blue', 'linestyle': '--'}]
{'c': 'blue', 'linestyle': '--'},
{'c': 'black', 'linestyle': '-'}]


def plot_on_dataset(X, y, ax, name):
@@ -49,7 +59,7 @@ def plot_on_dataset(X, y, ax, name):

for label, param in zip(labels, params):
print("training: %s" % label)
mlp = MLPClassifier(verbose=0, algorithm='sgd', random_state=0,
mlp = MLPClassifier(verbose=0, random_state=0,
max_iter=max_iter, **param)
mlp.fit(X, y)
mlps.append(mlp)
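
The plotting half of plot_on_dataset is not shown in this hunk; a rough sketch of that step, assuming each fitted estimator exposes the loss_curve_ attribute that MLPClassifier records during training, could look like this:

    # draw one training-loss curve per strategy on the given axes
    for mlp, label, args in zip(mlps, labels, plot_args):
        ax.plot(mlp.loss_curve_, label=label, **args)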
21 changes: 11 additions & 10 deletions examples/neural_networks/plot_mnist_filters.py
@@ -4,20 +4,21 @@
=====================================
Sometimes looking at the learned coefficients of a neural network can provide
inside into the learning behavior. For example if weights look unstructured,
maybe a weight was not used at all, or if very large coefficients exist, maybe
insight into the learning behavior. For example if weights look unstructured,
maybe some were not used at all, or if very large coefficients exist, maybe
regularization was too low or the learning rate too high.
This example shows how to plot some of the first layer weights in an
MLPClassifier trained on the MNIST dataset.
The input data consists of 28x28 pixel handwritten digits, leading to 784
features in the dataset. Therefore the first layer weight have the shape (784,
hidden_layer_sizes[0]). We can therefore visualize a single column of the
weight matrix as a 28x28 pixel image.
features in the dataset. Therefore the first layer weight matrix has the shape
(784, hidden_layer_sizes[0]). We can therefore visualize a single column of
the weight matrix as a 28x28 pixel image.
To make the example run faster, we use very few hidden units, and train only
for a very short time. Training longer would result in much smoother weights.
for a very short time. Training longer would result in weights with a much
smoother spatial appearance.
"""
print(__doc__)

@@ -31,8 +32,8 @@
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

#mlp = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
# algorithm='sgd', verbose=10, tol=1e-4, random_state=1)
# mlp = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
# algorithm='sgd', verbose=10, tol=1e-4, random_state=1)
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
algorithm='sgd', verbose=10, tol=1e-4, random_state=1,
learning_rate_init=.1)
@@ -45,8 +46,8 @@
# use global min / max to ensure all weights are shown on the same scale
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin, vmax=.5
* vmax)
ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
vmax=.5 * vmax)
ax.set_xticks(())
ax.set_yticks(())

50 changes: 30 additions & 20 deletions sklearn/neural_network/base.py → sklearn/neural_network/_base.py
@@ -76,7 +76,7 @@ def relu(X):


def softmax(X):
"""Compute the K-way softmax function inplace.
"""Compute the K-way softmax function inplace.
Parameters
----------
@@ -99,13 +99,17 @@ def softmax(X):
'relu': relu, 'softmax': softmax}


def logistic_derivative(Z):
"""Compute the derivative of the logistic function.
def inplace_logistic_derivative(Z):
"""Compute the derivative of the logistic function given output value
from logistic function
It exploits the fact that the derivative is a simple function of the output
value from logistic function
Parameters
----------
Z : {array-like, sparse matrix}, shape (n_samples, n_features)
The input data.
The input data, which is the output of the logistic function.
Returns
-------
@@ -115,13 +119,17 @@ def logistic_derivative(Z):
return Z * (1 - Z)


def tanh_derivative(Z):
"""Compute the derivative of the hyperbolic tan function.
def inplace_tanh_derivative(Z):
"""Compute the derivative of the hyperbolic tan function given output value
from hyperbolic tan
It exploits the fact that the derivative is a simple function of the output
value from hyperbolic tan
Parameters
----------
Z : {array-like, sparse matrix}, shape (n_samples, n_features)
The input data.
The input data, which is the output of the hyperbolic tan function.
Returns
-------
@@ -131,13 +139,14 @@ def tanh_derivative(Z):
return 1 - (Z ** 2)


def relu_derivative(Z):
"""Compute the derivative of the rectified linear unit function.
def inplace_relu_derivative(Z):
"""Compute the derivative of the rectified linear unit function given output
value from relu
Parameters
----------
Z : {array-like, sparse matrix}, shape (n_samples, n_features)
The input data.
The input data, which is the output of the relu function.
Returns
-------
@@ -147,8 +156,9 @@ def relu_derivative(Z):
return (Z > 0).astype(Z.dtype)


DERIVATIVES = {'tanh': tanh_derivative, 'logistic': logistic_derivative,
'relu': relu_derivative}
DERIVATIVES = {'tanh': inplace_tanh_derivative,
'logistic': inplace_logistic_derivative,
'relu': inplace_relu_derivative}
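
A quick numeric illustration of the identity these in-place helpers rely on, here for the logistic case where the derivative with respect to x equals sigma(x) * (1 - sigma(x)); the check below is a standalone sketch, not code from this module:

    import numpy as np

    x = np.linspace(-3., 3., 7)
    sig = 1. / (1. + np.exp(-x))                    # logistic output
    analytic = np.exp(-x) / (1. + np.exp(-x)) ** 2  # d(sigma)/dx computed from x
    from_output = sig * (1. - sig)                  # what inplace_logistic_derivative returns
    assert np.allclose(analytic, from_output)
    # the tanh case works the same way: d(tanh)/dx == 1 - tanh(x) ** 2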


def squared_loss(y_true, y_pred):
@@ -157,14 +167,14 @@ def squared_loss(y_true, y_pred):
Parameters
----------
y_true : array-like or label indicator matrix
Ground truth (correct) labels.
Ground truth (correct) values.
y_pred : array-like or label indicator matrix
Predicted labels, as returned by a regression estimator.
Predicted values, as returned by a regression estimator.
Returns
-------
score : float
loss : float
The degree to which the samples are correctly predicted.
"""
return ((y_true - y_pred) ** 2).mean() / 2
@@ -178,13 +188,13 @@ def log_loss(y_true, y_prob):
y_true : array-like or label indicator matrix
Ground truth (correct) labels.
y_pred : array-like of float, shape = (n_samples, n_classes)
y_prob : array-like of float, shape = (n_samples, n_classes)
Predicted probabilities, as returned by a classifier's
predict_proba method.
Returns
-------
score : float
loss : float
The degree to which the samples are correctly predicted.
"""
y_prob = np.clip(y_prob, 1e-10, 1 - 1e-10)
@@ -209,19 +219,19 @@ def binary_log_loss(y_true, y_prob):
y_true : array-like or label indicator matrix
Ground truth (correct) labels.
y_pred : array-like of float, shape = (n_samples, n_classes)
y_prob : array-like of float, shape = (n_samples, n_classes)
Predicted probabilities, as returned by a classifier's
predict_proba method.
Returns
-------
score : float
loss : float
The degree to which the samples are correctly predicted.
"""
y_prob = np.clip(y_prob, 1e-10, 1 - 1e-10)

return -np.sum(y_true * np.log(y_prob) +
(1 - y_true) * np.log(1 - y_prob)) / y_prob.shape[0]
(1 - y_true) * np.log(1 - y_prob)) / y_prob.shape[0]
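
A small worked example of the binary log loss above on made-up probabilities, using the same clipping and averaging as the function:

    import numpy as np

    y_true = np.array([[1.], [0.], [1.]])
    y_prob = np.array([[0.9], [0.2], [0.6]])
    y_prob = np.clip(y_prob, 1e-10, 1 - 1e-10)
    loss = -np.sum(y_true * np.log(y_prob) +
                   (1 - y_true) * np.log(1 - y_prob)) / y_prob.shape[0]
    # -(log(0.9) + log(0.8) + log(0.6)) / 3 is roughly 0.2798
    print(loss)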


LOSS_FUNCTIONS = {'squared_loss': squared_loss, 'log_loss': log_loss,