[MRG+2] Faster isotonic rebased (scikit-learn#7444)
ogrisel authored Sep 16, 2016
1 parent 1054e07 commit f3baca6
Showing 5 changed files with 79 additions and 69 deletions.
18 changes: 15 additions & 3 deletions benchmarks/bench_isotonic.py
@@ -22,18 +22,26 @@
 
 
 def generate_perturbed_logarithm_dataset(size):
-    return np.random.randint(-50, 50, size=n) \
-        + 50. * np.log(1 + np.arange(n))
+    return (np.random.randint(-50, 50, size=size) +
+            50. * np.log(1 + np.arange(size)))
 
 
 def generate_logistic_dataset(size):
     X = np.sort(np.random.normal(size=size))
     return np.random.random(size=size) < 1.0 / (1.0 + np.exp(-X))
 
 
+def generate_pathological_dataset(size):
+    # Triggers O(n^2) complexity on the original implementation.
+    return np.r_[np.arange(size),
+                 np.arange(-(size - 1), size),
+                 np.arange(-(size - 1), 1)]
+
+
 DATASET_GENERATORS = {
     'perturbed_logarithm': generate_perturbed_logarithm_dataset,
-    'logistic': generate_logistic_dataset
+    'logistic': generate_logistic_dataset,
+    'pathological': generate_pathological_dataset,
 }


@@ -53,6 +61,8 @@ def bench_isotonic_regression(Y):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description="Isotonic Regression benchmark tool")
+    parser.add_argument('--seed', type=int,
+                        help="RNG seed")
     parser.add_argument('--iterations', type=int, required=True,
                         help="Number of iterations to average timings over "
                              "for each problem size")
@@ -67,6 +77,8 @@ def bench_isotonic_regression(Y):
 
     args = parser.parse_args()
 
+    np.random.seed(args.seed)
+
     timings = []
     for exponent in range(args.log_min_problem_size,
                           args.log_max_problem_size):
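Why the new `pathological` generator matters: each ramp in the concatenation restarts below values that have already been fitted, so the original single-array PAVA has to re-pool an ever longer prefix on every sweep, which is where the quadratic behavior comes from. A minimal sketch of the generated input, using a hypothetical `size=4` (illustration only, not part of the diff):

```python
import numpy as np

size = 4  # hypothetical, small enough to print
y = np.r_[np.arange(size),                # 0 1 2 3
          np.arange(-(size - 1), size),   # -3 -2 -1 0 1 2 3
          np.arange(-(size - 1), 1)]      # -3 -2 -1 0
print(y)  # [ 0  1  2  3 -3 -2 -1  0  1  2  3 -3 -2 -1  0]
```

The new `--seed` option makes the randomized generators (`perturbed_logarithm`, `logistic`) reproducible across benchmark runs; the CLI option for selecting a dataset is not visible in this hunk.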
6 changes: 6 additions & 0 deletions doc/whats_new.rst
@@ -340,6 +340,12 @@ Enhancements
      (`#7419 <https://github.com/scikit-learn/scikit-learn/pull/7419>`_)
      By `Gregory Stupp`_ and `Joel Nothman`_.
 
+   - Isotonic regression (:mod:`isotonic`) now uses a better algorithm to avoid
+     `O(n^2)` behavior in pathological cases, and is also generally faster
+     (`#6691 <https://github.com/scikit-learn/scikit-learn/pull/6691>`_).
+     By `Antony Lee`_.
+
+
 Bug fixes
 .........

98 changes: 46 additions & 52 deletions sklearn/_isotonic.pyx
@@ -1,4 +1,4 @@
-# Author: Nelle Varoquaux, Andrew Tulloch
+# Author: Nelle Varoquaux, Andrew Tulloch, Antony Lee
 
 # Uses the pool adjacent violators algorithm (PAVA), with the
 # enhancement of searching for the longest decreasing subsequence to
@@ -10,73 +10,67 @@ cimport cython
 
 ctypedef np.float64_t DOUBLE
 
+np.import_array()
+
 
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-def _isotonic_regression(np.ndarray[DOUBLE, ndim=1] y,
-                         np.ndarray[DOUBLE, ndim=1] weight,
-                         np.ndarray[DOUBLE, ndim=1] solution):
+def _inplace_contiguous_isotonic_regression(DOUBLE[::1] y, DOUBLE[::1] w):
     cdef:
-        DOUBLE numerator, denominator, ratio
-        Py_ssize_t i, pooled, n, k
-
-    n = y.shape[0]
-    # The algorithm proceeds by iteratively updating the solution
-    # array.
+        Py_ssize_t n = y.shape[0], i, k
+        DOUBLE prev_y, sum_wy, sum_w
+        Py_ssize_t[::1] target = np.arange(n, dtype=np.intp)
 
-    # TODO - should we just pass in a pre-copied solution
-    # array and mutate that?
-    for i in range(n):
-        solution[i] = y[i]
+    # target describes a list of blocks. At any time, if [i..j] (inclusive) is
+    # an active block, then target[i] := j and target[j] := i.
 
-    if n <= 1:
-        return solution
+    # For "active" indices (block starts):
+    # w[i] := sum{w_orig[j], j=[i..target[i]]}
+    # y[i] := sum{y_orig[j]*w_orig[j], j=[i..target[i]]} / w[i]
 
-    n -= 1
-    while 1:
-        # repeat until there are no more adjacent violators.
+    with nogil:
         i = 0
-        pooled = 0
         while i < n:
-            k = i
-            while k < n and solution[k] >= solution[k + 1]:
-                k += 1
-            if solution[i] != solution[k]:
-                # solution[i:k + 1] is a decreasing subsequence, so
-                # replace each point in the subsequence with the
-                # weighted average of the subsequence.
-
-                # TODO: explore replacing each subsequence with a
-                # _single_ weighted point, and reconstruct the whole
-                # sequence from the sequence of collapsed points.
-                # Theoretically should reduce running time, though
-                # initial experiments weren't promising.
-                numerator = 0.0
-                denominator = 0.0
-                for j in range(i, k + 1):
-                    numerator += solution[j] * weight[j]
-                    denominator += weight[j]
-                ratio = numerator / denominator
-                for j in range(i, k + 1):
-                    solution[j] = ratio
-                pooled = 1
-            i = k + 1
-        # Check for convergence
-        if pooled == 0:
-            break
-
-    return solution
+            k = target[i] + 1
+            if k == n:
+                break
+            if y[i] < y[k]:
+                i = k
+                continue
+            sum_wy = w[i] * y[i]
+            sum_w = w[i]
+            while True:
+                # We are within a decreasing subsequence.
+                prev_y = y[k]
+                sum_wy += w[k] * y[k]
+                sum_w += w[k]
+                k = target[k] + 1
+                if k == n or prev_y < y[k]:
+                    # Non-singleton decreasing subsequence is finished,
+                    # update first entry.
+                    y[i] = sum_wy / sum_w
+                    w[i] = sum_w
+                    target[i] = k - 1
+                    target[k - 1] = i
+                    if i > 0:
+                        # Backtrack if we can. This makes the algorithm
+                        # single-pass and ensures O(n) complexity.
+                        i = target[i - 1]
+                    # Otherwise, restart from the same point.
+                    break
+    # Reconstruct the solution.
+    i = 0
+    while i < n:
+        k = target[i] + 1
+        y[i + 1 : k] = y[i]
+        i = k
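For readers who would rather not decode Cython, here is a pure-Python transcription of the block-merging scheme above (a hedged sketch with a made-up name, `pava_blocks`; the shipped implementation is the Cython `_inplace_contiguous_isotonic_regression`):

```python
import numpy as np

def pava_blocks(y, w):
    # Pure-Python sketch of the target-array PAVA; mutates the float
    # arrays y and w in place, mirroring the Cython loop above.
    n = len(y)
    # For an active block [i..j], target[i] == j and target[j] == i.
    target = np.arange(n, dtype=np.intp)
    i = 0
    while i < n:
        k = target[i] + 1          # start of the next block
        if k == n:
            break
        if y[i] < y[k]:            # no violation; advance to the next block
            i = k
            continue
        sum_wy = w[i] * y[i]
        sum_w = w[i]
        while True:                # merge while the sequence decreases
            prev_y = y[k]
            sum_wy += w[k] * y[k]
            sum_w += w[k]
            k = target[k] + 1
            if k == n or prev_y < y[k]:
                y[i] = sum_wy / sum_w   # pooled weighted mean
                w[i] = sum_w
                target[i] = k - 1       # link both ends of the block
                target[k - 1] = i
                if i > 0:
                    i = target[i - 1]   # backtrack to the previous block
                break
    i = 0
    while i < n:                   # expand each block head over its block
        k = target[i] + 1
        y[i + 1:k] = y[i]
        i = k

y = np.array([10.0, 0.0, 2.0])
w = np.ones(3)
pava_blocks(y, w)
print(y)  # y is now [4, 4, 4], matching the new test below
```

The backtracking step is what keeps this linear: a merge can only invalidate the boundary with the block immediately to its left, so the scan never has to restart from the beginning.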


 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
 def _make_unique(np.ndarray[dtype=np.float64_t] X,
-                np.ndarray[dtype=np.float64_t] y,
-                np.ndarray[dtype=np.float64_t] sample_weights):
+                 np.ndarray[dtype=np.float64_t] y,
+                 np.ndarray[dtype=np.float64_t] sample_weights):
     """Average targets for duplicate X, drop duplicates.
 
     Aggregates duplicate X values into a single X value where
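The docstring is truncated in this view, but the helper's contract is clear from its use in `isotonic.py`: collapse duplicate X values into one point whose target is the weighted average and whose weight is the sum. A hedged numpy sketch of that behavior (illustrative only; `make_unique_sketch` is a made-up name, and the real `_make_unique` is compiled Cython):

```python
import numpy as np

def make_unique_sketch(X, y, w):
    # Collapse duplicate X values: weighted-average their targets and
    # sum their weights, mirroring what _make_unique is documented to do.
    X_out, inverse = np.unique(X, return_inverse=True)
    w_out = np.bincount(inverse, weights=w)
    y_out = np.bincount(inverse, weights=w * y) / w_out
    return X_out, y_out, w_out

X = np.array([1.0, 1.0, 2.0, 3.0])
y = np.array([0.0, 4.0, 1.0, 5.0])
w = np.array([1.0, 3.0, 1.0, 1.0])
print(make_unique_sketch(X, y, w))
# X -> [1, 2, 3]; y -> [3, 1, 5] since (1*0 + 3*4)/4 == 3; w -> [4, 1, 1]
```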
22 changes: 8 additions & 14 deletions sklearn/isotonic.py
@@ -10,7 +10,7 @@
 from .utils import as_float_array, check_array, check_consistent_length
 from .utils import deprecated
 from .utils.fixes import astype
-from ._isotonic import _isotonic_regression, _make_unique
+from ._isotonic import _inplace_contiguous_isotonic_regression, _make_unique
 import warnings
 import math

@@ -120,28 +120,22 @@ def isotonic_regression(y, sample_weight=None, y_min=None, y_max=None,
     "Active set algorithms for isotonic regression; A unifying framework"
     by Michael J. Best and Nilotpal Chakravarti, section 3.
     """
-    y = np.asarray(y, dtype=np.float64)
+    order = np.s_[:] if increasing else np.s_[::-1]
+    y = np.array(y[order], dtype=np.float64)
     if sample_weight is None:
-        sample_weight = np.ones(len(y), dtype=y.dtype)
+        sample_weight = np.ones(len(y), dtype=np.float64)
     else:
-        sample_weight = np.asarray(sample_weight, dtype=np.float64)
-    if not increasing:
-        y = y[::-1]
-        sample_weight = sample_weight[::-1]
-
-    solution = np.empty(len(y))
-    y_ = _isotonic_regression(y, sample_weight, solution)
-    if not increasing:
-        y_ = y_[::-1]
+        sample_weight = np.array(sample_weight[order], dtype=np.float64)
 
+    _inplace_contiguous_isotonic_regression(y, sample_weight)
     if y_min is not None or y_max is not None:
         # Older versions of np.clip don't accept None as a bound, so use np.inf
         if y_min is None:
             y_min = -np.inf
         if y_max is None:
             y_max = np.inf
-        np.clip(y_, y_min, y_max, y_)
-    return y_
+        np.clip(y, y_min, y_max, y)
+    return y[order]


class IsotonicRegression(BaseEstimator, TransformerMixin, RegressorMixin):
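The refactor replaces the old reverse/copy/reverse dance with a single `order` slice: reverse on the way in when `increasing=False`, run the in-place Cython pass, and apply the same slice again on the way out. A quick usage sketch against the public function (expected outputs hand-computed, not taken from the diff):

```python
import numpy as np
from sklearn.isotonic import isotonic_regression

# Unit weights: [10, 0, 2] pools into a single block, i.e. the grand mean.
print(isotonic_regression(np.array([10.0, 0.0, 2.0])))
# -> [4, 4, 4]

# increasing=False: reversed via `order`, fitted, reversed back.
print(isotonic_regression(np.array([1.0, 3.0, 0.0]), increasing=False))
# -> [2, 2, 0]

# Weights move the pooled value: (1 * 10 + 3 * 0) / (1 + 3) == 2.5.
print(isotonic_regression(np.array([10.0, 0.0]),
                          sample_weight=np.array([1.0, 3.0])))
# -> [2.5, 2.5]
```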
4 changes: 4 additions & 0 deletions sklearn/tests/test_isotonic.py
@@ -80,6 +80,10 @@ def test_isotonic_regression():
     y_ = np.array([3, 6, 6, 8, 8, 8, 10])
     assert_array_equal(y_, isotonic_regression(y))
 
+    y = np.array([10, 0, 2])
+    y_ = np.array([4, 4, 4])
+    assert_array_equal(y_, isotonic_regression(y))
+
     x = np.arange(len(y))
     ir = IsotonicRegression(y_min=0., y_max=1.)
     ir.fit(x, y)
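A further hedged check tying the test suite to the benchmark's new generator: whatever the input, the fitted output must be non-decreasing, which is cheap to assert on the pathological sequence (illustrative, not part of the diff):

```python
import numpy as np
from sklearn.isotonic import isotonic_regression

size = 100
y = np.r_[np.arange(size),
          np.arange(-(size - 1), size),
          np.arange(-(size - 1), 1)].astype(float)
fit = isotonic_regression(y)
assert np.all(np.diff(fit) >= 0)  # isotonicity holds on the nasty input
```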
