Merge pull request scikit-learn#7393 from ogrisel/mlp-multilabel
[MRG] Fix for MLP multilabel predict_proba and removal of decision_function
ogrisel authored Sep 12, 2016
2 parents b966b44 + 1e8ca7d commit 5a018a3
Showing 3 changed files with 105 additions and 90 deletions.
24 changes: 8 additions & 16 deletions doc/modules/neural_networks_supervised.rst
@@ -107,14 +107,6 @@ contains the weight matrices that constitute the model parameters::
>>> [coef.shape for coef in clf.coefs_]
[(2, 5), (5, 2), (2, 1)]

To get the raw values before applying the output activation function, run the
following command,

use :meth:`MLPClassifier.decision_function`::

>>> clf.decision_function([[2., 2.], [1., 2.]]) # doctest: +ELLIPSIS
array([ 47.6..., 47.6...])

Currently, :class:`MLPClassifier` supports only the
Cross-Entropy loss function, which allows probability estimates by running the
``predict_proba`` method.
@@ -125,23 +117,23 @@ classification, it minimizes the Cross-Entropy loss function, giving a vector
of probability estimates :math:`P(y|x)` per sample :math:`x`::

>>> clf.predict_proba([[2., 2.], [1., 2.]]) # doctest: +ELLIPSIS
array([[ 0., 1.],
[ 0., 1.]])
array([[ 1.967...e-04, 9.998...e-01],
[ 1.967...e-04, 9.998...e-01]])

:class:`MLPClassifier` supports multi-class classification by
applying `Softmax <https://en.wikipedia.org/wiki/Softmax_activation_function>`_
as the output function.

Further, the algorithm supports :ref:`multi-label classification <multiclass>`
in which a sample can belong to more than one class. For each class, the output
of :meth:`MLPClassifier.decision_function` passes through the
logistic function. Values larger or equal to `0.5` are rounded to `1`,
otherwise to `0`. For a predicted output of a sample, the indices where the
value is `1` represents the assigned classes of that sample::
in which a sample can belong to more than one class. For each class, the raw
output passes through the logistic function. Values larger or equal to `0.5`
are rounded to `1`, otherwise to `0`. For a predicted output of a sample, the
indices where the value is `1` represents the assigned classes of that sample::

>>> X = [[0., 0.], [1., 1.]]
>>> y = [[0, 1], [1, 1]]
>>> clf = MLPClassifier(algorithm='l-bfgs', alpha=1e-5, hidden_layer_sizes=(15,), random_state=1)
>>> clf = MLPClassifier(algorithm='l-bfgs', alpha=1e-5,
... hidden_layer_sizes=(15,), random_state=1)
>>> clf.fit(X, y)
MLPClassifier(activation='relu', algorithm='l-bfgs', alpha=1e-05,
batch_size='auto', beta_1=0.9, beta_2=0.999, early_stopping=False,
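
A minimal usage sketch for the multi-label example above (plain Python rather than doctest output; the probability values are omitted because they depend on the fitted weights):

    proba = clf.predict_proba([[1., 2.]])   # shape (1, 2): one independent logistic output per class
    labels = clf.predict([[1., 2.]])        # entries of proba >= 0.5 become 1, the rest 0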
95 changes: 41 additions & 54 deletions sklearn/neural_network/multilayer_perceptron.py
@@ -13,7 +13,6 @@
import warnings

from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ._base import logistic, softmax
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
from ._stochastic_optimizers import SGDOptimizer, AdamOptimizer
from ..model_selection import train_test_split
@@ -81,7 +80,7 @@ def _unpack(self, packed_parameters):
start, end = self._intercept_indptr[i]
self.intercepts_[i] = packed_parameters[start:end]

def _forward_pass(self, activations, with_output_activation=True):
def _forward_pass(self, activations):
"""Perform a forward pass on the network by computing the values
of the neurons in the hidden layers and the output layer.
@@ -107,9 +106,8 @@ def _forward_pass(self, activations, with_output_activation=True):
activations[i + 1] = hidden_activation(activations[i + 1])

# For the last layer
if with_output_activation:
output_activation = ACTIVATIONS[self.out_activation_]
activations[i + 1] = output_activation(activations[i + 1])
output_activation = ACTIVATIONS[self.out_activation_]
activations[i + 1] = output_activation(activations[i + 1])

return activations
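
For reference, a rough sketch of the output activations that the forward pass now always applies (simplified NumPy re-implementations, not the module's actual ``ACTIVATIONS`` table):

    import numpy as np

    def logistic(z):
        # element-wise sigmoid: binary and multi-label outputs
        return 1.0 / (1.0 + np.exp(-z))

    def softmax(z):
        # row-wise softmax: multi-class outputs
        z = z - z.max(axis=1, keepdims=True)   # shift for numerical stability
        e = np.exp(z)
        return e / e.sum(axis=1, keepdims=True)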

@@ -222,7 +220,10 @@ def _backprop(self, X, y, activations, deltas, coef_grads,
activations = self._forward_pass(activations)

# Get loss
loss = LOSS_FUNCTIONS[self.loss](y, activations[-1])
loss_func_name = self.loss
if loss_func_name == 'log_loss' and self.out_activation_ == 'logistic':
loss_func_name = 'binary_log_loss'
loss = LOSS_FUNCTIONS[loss_func_name](y, activations[-1])
# Add L2 regularization term to loss
values = np.sum(
np.array([np.dot(s.ravel(), s.ravel()) for s in self.coefs_]))
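
The two loss functions this branch selects between can be sketched as follows (simplified, assuming 2-D NumPy arrays; not the exact implementations in ``_base.py``):

    import numpy as np

    def log_loss(y_true, y_prob):
        # categorical cross-entropy, paired with a softmax output layer
        return -np.sum(y_true * np.log(y_prob)) / y_prob.shape[0]

    def binary_log_loss(y_true, y_prob):
        # per-output Bernoulli cross-entropy, paired with logistic outputs
        return -np.sum(y_true * np.log(y_prob) +
                       (1 - y_true) * np.log(1 - y_prob)) / y_prob.shape[0]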
@@ -272,18 +273,14 @@ def _initialize(self, y, layer_units):
# Output for binary class and multi-label
else:
self.out_activation_ = 'logistic'
if self.loss == 'log_loss':
self.loss = 'binary_log_loss'

# Initialize coefficient and intercept layers
self.coefs_ = []
self.intercepts_ = []

for i in range(self.n_layers_ - 1):
rng = check_random_state(self.random_state)
coef_init, intercept_init = self._init_coef(layer_units[i],
layer_units[i + 1],
rng)
layer_units[i + 1])
self.coefs_.append(coef_init)
self.intercepts_.append(intercept_init)

@@ -296,7 +293,7 @@ def _initialize(self, y, layer_units):
else:
self.best_loss_ = np.inf

def _init_coef(self, fan_in, fan_out, rng):
def _init_coef(self, fan_in, fan_out):
if self.activation == 'logistic':
# Use the initialization method recommended by
# Glorot et al.
@@ -308,8 +305,10 @@ def _init_coef(self, fan_in, fan_out, rng):
raise ValueError("Unknown activation function %s" %
self.activation)

coef_init = rng.uniform(-init_bound, init_bound, (fan_in, fan_out))
intercept_init = rng.uniform(-init_bound, init_bound, fan_out)
coef_init = self._random_state.uniform(-init_bound, init_bound,
(fan_in, fan_out))
intercept_init = self._random_state.uniform(-init_bound, init_bound,
fan_out)
return coef_init, intercept_init
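
For context, a generic Glorot-style uniform initialization for a single layer looks roughly like this (illustrative bound; the module computes an activation-dependent ``init_bound`` that is not visible in this hunk):

    import numpy as np

    def glorot_uniform(fan_in, fan_out, rng):
        # symmetric uniform draw with a fan-based bound
        bound = np.sqrt(6.0 / (fan_in + fan_out))
        coef = rng.uniform(-bound, bound, (fan_in, fan_out))
        intercept = rng.uniform(-bound, bound, fan_out)
        return coef, intercept

    coef, intercept = glorot_uniform(2, 5, np.random.RandomState(0))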

def _fit(self, X, y, incremental=False):
@@ -337,6 +336,9 @@ def _fit(self, X, y, incremental=False):
layer_units = ([n_features] + hidden_layer_sizes +
[self.n_outputs_])

# check random state
self._random_state = check_random_state(self.random_state)

if not hasattr(self, 'coefs_') or (not self.warm_start and not
incremental):
# First time training the model
@@ -419,9 +421,11 @@ def _validate_hyperparameters(self):
if self.learning_rate not in ["constant", "invscaling", "adaptive"]:
raise ValueError("learning rate %s is not supported. " %
self.learning_rate)
if self.algorithm not in _STOCHASTIC_ALGOS + ["l-bfgs"]:
raise ValueError("The algorithm %s is not supported. " %
self.algorithm)
supported_algorithms = _STOCHASTIC_ALGOS + ["l-bfgs"]
if self.algorithm not in supported_algorithms:
raise ValueError("The algorithm %s is not supported. "
" Expected one of: %s" %
(self.algorithm, ", ".join(supported_algorithms)))

def _fit_lbfgs(self, X, y, activations, deltas, coef_grads,
intercept_grads, layer_units):
@@ -465,7 +469,6 @@ def _fit_lbfgs(self, X, y, activations, deltas, coef_grads,

def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
intercept_grads, layer_units, incremental):
rng = check_random_state(self.random_state)

if not incremental or not hasattr(self, '_optimizer'):
params = self.coefs_ + self.intercepts_
@@ -483,7 +486,7 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
early_stopping = self.early_stopping and not incremental
if early_stopping:
X, X_val, y, y_val = train_test_split(
X, y, random_state=self.random_state,
X, y, random_state=self._random_state,
test_size=self.validation_fraction)
if isinstance(self, ClassifierMixin):
y_val = self.label_binarizer_.inverse_transform(y_val)
@@ -500,7 +503,7 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,

try:
for it in range(self.max_iter):
X, y = shuffle(X, y, random_state=rng)
X, y = shuffle(X, y, random_state=self._random_state)
accumulated_loss = 0.0
for batch_slice in gen_batches(n_samples, batch_size):
activations[0] = X[batch_slice]
@@ -629,14 +632,14 @@ def partial_fit(self):
"""
if self.algorithm not in _STOCHASTIC_ALGOS:
raise AttributeError("partial_fit is only available for stochastic"
"optimization algorithms. %s is not"
" optimization algorithms. %s is not"
" stochastic" % self.algorithm)
return self._partial_fit

def _partial_fit(self, X, y, classes=None):
return self._fit(X, y, incremental=True)

def _decision_scores(self, X):
def _predict(self, X):
"""Predict using the trained model
Parameters
@@ -667,7 +670,7 @@ def _decision_scores(self, X):
activations.append(np.empty((X.shape[0],
layer_units[i + 1])))
# forward propagate
self._forward_pass(activations, with_output_activation=False)
self._forward_pass(activations)
y_pred = activations[-1]

return y_pred
@@ -913,27 +916,6 @@ def _validate_input(self, X, y, incremental):
y = self.label_binarizer_.transform(y)
return X, y

def decision_function(self, X):
"""Decision function of the mlp model
Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
The input data.
Returns
-------
y : array-like, shape (n_samples,) or (n_samples, n_classes)
The values of decision function for each class in the model.
"""
check_is_fitted(self, "coefs_")
y_scores = self._decision_scores(X)

if self.n_outputs_ == 1:
return y_scores.ravel()
else:
return y_scores

def predict(self, X):
"""Predict using the multi-layer perceptron classifier
@@ -948,10 +930,12 @@ def predict(self, X):
The predicted classes.
"""
check_is_fitted(self, "coefs_")
y_scores = self.decision_function(X)
y_scores = ACTIVATIONS[self.out_activation_](y_scores)
y_pred = self._predict(X)

return self.label_binarizer_.inverse_transform(y_scores)
if self.n_outputs_ == 1:
y_pred = y_pred.ravel()

return self.label_binarizer_.inverse_transform(y_pred)

@property
def partial_fit(self):
@@ -979,7 +963,7 @@ def partial_fit(self):
"""
if self.algorithm not in _STOCHASTIC_ALGOS:
raise AttributeError("partial_fit is only available for stochastic"
"optimization algorithms. %s is not"
" optimization algorithms. %s is not"
" stochastic" % self.algorithm)
return self._partial_fit

@@ -1022,13 +1006,16 @@ def predict_proba(self, X):
The predicted probability of the sample for each class in the
model, where classes are ordered as they are in `self.classes_`.
"""
y_scores = self.decision_function(X)
check_is_fitted(self, "coefs_")
y_pred = self._predict(X)

if self.n_outputs_ == 1:
y_pred = y_pred.ravel()

if y_scores.ndim == 1:
y_scores = logistic(y_scores)
return np.vstack([1 - y_scores, y_scores]).T
if y_pred.ndim == 1:
return np.vstack([1 - y_pred, y_pred]).T
else:
return softmax(y_scores)
return y_pred
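
Since the output activation is now applied inside ``_forward_pass``, the array returned by ``_predict`` is already a probability; a rough sketch of the branching above (illustrative only, with ``p`` standing in for that activated output):

    import numpy as np

    def predict_proba_sketch(p):
        p = np.asarray(p, dtype=float)
        if p.ndim == 1:                     # binary: one logistic output per sample
            return np.vstack([1 - p, p]).T  # -> shape (n_samples, 2)
        return p                            # multi-class (softmax rows) or multi-label

    predict_proba_sketch([0.2, 0.9])          # binary case
    predict_proba_sketch([[0.1, 0.7, 0.2]])   # multi-class case: each row already sums to 1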


class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
@@ -1258,7 +1245,7 @@ def predict(self, X):
The predicted values.
"""
check_is_fitted(self, "coefs_")
y_pred = self._decision_scores(X)
y_pred = self._predict(X)
if y_pred.shape[1] == 1:
return y_pred.ravel()
return y_pred
