Skip to content

Commit

Permalink
MAINT: Move lstsq to umath_linalg
Browse files Browse the repository at this point in the history
This does not yet enable any broadcasting, but makes doing so in future far
easier.
  • Loading branch information
eric-wieser committed Apr 11, 2018
1 parent efc254e commit 3ef55be
Showing 1 changed file with 14 additions and 50 deletions.
64 changes: 14 additions & 50 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):
def _raise_linalgerror_svd_nonconvergence(err, flag):
raise LinAlgError("SVD did not converge")

def _raise_linalgerror_lstsq(err, flag):
raise LinAlgError("SVD did not converge in Linear Least Squares")

def get_linalg_error_extobj(callback):
extobj = list(_linalg_error_extobj) # make a copy
extobj[2] = callback
Expand Down Expand Up @@ -1997,7 +2000,6 @@ def lstsq(a, b, rcond="warn"):
>>> plt.show()
"""
import math
a, _ = _makearray(a)
b, wrap = _makearray(b)
is_1d = b.ndim == 1
Expand All @@ -2008,7 +2010,6 @@ def lstsq(a, b, rcond="warn"):
m = a.shape[0]
n = a.shape[1]
n_rhs = b.shape[1]
ldb = max(n, m)
if m != b.shape[0]:
raise LinAlgError('Incompatible dimensions')

Expand All @@ -2028,62 +2029,25 @@ def lstsq(a, b, rcond="warn"):
FutureWarning, stacklevel=2)
rcond = -1
if rcond is None:
rcond = finfo(t).eps * ldb

bstar = zeros((ldb, n_rhs), t)
bstar[:m, :n_rhs] = b
a, bstar = _fastCopyAndTranspose(t, a, bstar)
a, bstar = _to_native_byte_order(a, bstar)
s = zeros((min(m, n),), real_t)
# This line:
# * is incorrect, according to the LAPACK documentation
# * raises a ValueError if min(m,n) == 0
# * should not be calculated here anyway, as LAPACK should calculate
# `liwork` for us. But that only works if our version of lapack does
# not have this bug:
# http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00899.html
# Lapack_lite does have that bug...
nlvl = max( 0, int( math.log( float(min(m, n))/2. ) ) + 1 )
iwork = zeros((3*min(m, n)*nlvl+11*min(m, n),), fortran_int)
if isComplexType(t):
lapack_routine = lapack_lite.zgelsd
lwork = 1
rwork = zeros((lwork,), real_t)
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, -1, rwork, iwork, 0)
lrwork = int(rwork[0])
lwork = int(work[0].real)
work = zeros((lwork,), t)
rwork = zeros((lrwork,), real_t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, lwork, rwork, iwork, 0)
rcond = finfo(t).eps * max(n, m)

if m <= n:
gufunc = _umath_linalg.lstsq_m
else:
lapack_routine = lapack_lite.dgelsd
lwork = 1
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, -1, iwork, 0)
lwork = int(work[0])
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, lwork, iwork, 0)
if results['info'] > 0:
raise LinAlgError('SVD did not converge in Linear Least Squares')

# undo transpose imposed by fortran-order arrays
b_out = bstar.T
gufunc = _umath_linalg.lstsq_n

signature = 'DDd->Did' if isComplexType(t) else 'ddd->did'
extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
b_out, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)

# b_out contains both the solution and the components of the residuals
x = b_out[:n,:]
r_parts = b_out[n:,:]
x = b_out[...,:n,:]
r_parts = b_out[...,n:,:]
if isComplexType(t):
resids = sum(abs(r_parts)**2, axis=-2)
else:
resids = sum(r_parts**2, axis=-2)

rank = results['rank']

# remove the axis we added
if is_1d:
x = x.squeeze(axis=-1)
Expand Down

0 comments on commit 3ef55be

Please sign in to comment.