Skip to content

Commit

Permalink
BUG: Fix Akima1DInterpolator by returning linear interpolant for `y…
Browse files Browse the repository at this point in the history
….shape[0] == 2` (scipy#22278)

* BUG: interpolate: Akima1DInterpolator: use linear interpolation if there are only 2 points

reviewed at scipy#22278
  • Loading branch information
czgdp1807 authored Jan 9, 2025
1 parent 0a5c486 commit ef7437a
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 34 deletions.
76 changes: 42 additions & 34 deletions scipy/interpolate/_cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,10 @@ class Akima1DInterpolator(CubicHermiteSpline):
.. versionadded:: 1.13.0
extrapolate : {bool, None}, optional
If bool, determines whether to extrapolate to out-of-bounds points
based on first and last intervals, or to return NaNs. If None,
If bool, determines whether to extrapolate to out-of-bounds points
based on first and last intervals, or to return NaNs. If None,
``extrapolate`` is set to False.
Methods
-------
__call__
Expand Down Expand Up @@ -491,7 +491,7 @@ class Akima1DInterpolator(CubicHermiteSpline):
"""

def __init__(self, x, y, axis=0, *, method: Literal["akima", "makima"]="akima",
def __init__(self, x, y, axis=0, *, method: Literal["akima", "makima"]="akima",
extrapolate:bool | None = None):
if method not in {"akima", "makima"}:
raise NotImplementedError(f"`method`={method} is unsupported.")
Expand All @@ -509,37 +509,45 @@ def __init__(self, x, y, axis=0, *, method: Literal["akima", "makima"]="akima",
# Akima extrapolation historically False; parent class defaults to True.
extrapolate = False if extrapolate is None else extrapolate

# determine slopes between breakpoints
m = np.empty((x.size + 3, ) + y.shape[1:])
dx = dx[(slice(None), ) + (None, ) * (y.ndim - 1)]
m[2:-2] = np.diff(y, axis=0) / dx

# add two additional points on the left ...
m[1] = 2. * m[2] - m[3]
m[0] = 2. * m[1] - m[2]
# ... and on the right
m[-2] = 2. * m[-3] - m[-4]
m[-1] = 2. * m[-2] - m[-3]

# if m1 == m2 != m3 == m4, the slope at the breakpoint is not
# defined. This is the fill value:
t = .5 * (m[3:] + m[:-3])
# get the denominator of the slope t
dm = np.abs(np.diff(m, axis=0))
if method == "makima":
pm = np.abs(m[1:] + m[:-1])
f1 = dm[2:] + 0.5 * pm[2:]
f2 = dm[:-2] + 0.5 * pm[:-2]
if y.shape[0] == 2:
# edge case: only have two points, use linear interpolation
xp = x.reshape((x.shape[0],) + (1,)*(y.ndim-1))
hk = xp[1:] - xp[:-1]
mk = (y[1:] - y[:-1]) / hk
t = np.zeros_like(y)
t[...] = mk
else:
f1 = dm[2:]
f2 = dm[:-2]
f12 = f1 + f2
# These are the mask of where the slope at breakpoint is defined:
ind = np.nonzero(f12 > 1e-9 * np.max(f12, initial=-np.inf))
x_ind, y_ind = ind[0], ind[1:]
# Set the slope at breakpoint
t[ind] = (f1[ind] * m[(x_ind + 1,) + y_ind] +
f2[ind] * m[(x_ind + 2,) + y_ind]) / f12[ind]
# determine slopes between breakpoints
m = np.empty((x.size + 3, ) + y.shape[1:])
dx = dx[(slice(None), ) + (None, ) * (y.ndim - 1)]
m[2:-2] = np.diff(y, axis=0) / dx

# add two additional points on the left ...
m[1] = 2. * m[2] - m[3]
m[0] = 2. * m[1] - m[2]
# ... and on the right
m[-2] = 2. * m[-3] - m[-4]
m[-1] = 2. * m[-2] - m[-3]

# if m1 == m2 != m3 == m4, the slope at the breakpoint is not
# defined. This is the fill value:
t = .5 * (m[3:] + m[:-3])
# get the denominator of the slope t
dm = np.abs(np.diff(m, axis=0))
if method == "makima":
pm = np.abs(m[1:] + m[:-1])
f1 = dm[2:] + 0.5 * pm[2:]
f2 = dm[:-2] + 0.5 * pm[:-2]
else:
f1 = dm[2:]
f2 = dm[:-2]
f12 = f1 + f2
# These are the mask of where the slope at breakpoint is defined:
ind = np.nonzero(f12 > 1e-9 * np.max(f12, initial=-np.inf))
x_ind, y_ind = ind[0], ind[1:]
# Set the slope at breakpoint
t[ind] = (f1[ind] * m[(x_ind + 1,) + y_ind] +
f2[ind] * m[(x_ind + 2,) + y_ind]) / f12[ind]

super().__init__(x, y, t, axis=0, extrapolate=extrapolate)
self.axis = axis
Expand Down
44 changes: 44 additions & 0 deletions scipy/interpolate/tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,50 @@ def test_eval_3d(self):
yi[:, 1, 1] = 4. * yi_
xp_assert_close(ak(xi), yi)

def test_linear_interpolant_edge_case_1d(self):
x = np.array([0.0, 1.0], dtype=float)
y = np.array([0.5, 1.0])
akima = Akima1DInterpolator(x, y, axis=0, extrapolate=None)
xp_assert_close(akima(0.45), np.array(0.725))

def test_linear_interpolant_edge_case_2d(self):
x = np.array([0., 1.])
y = np.column_stack((x, 2. * x, 3. * x, 4. * x))

ak = Akima1DInterpolator(x, y)
xi = np.array([0.5, 1.])
yi = np.array([[0.5, 1., 1.5, 2. ],
[1., 2., 3., 4.]])
xp_assert_close(ak(xi), yi)

ak = Akima1DInterpolator(x, y.T, axis=1)
xp_assert_close(ak(xi), yi.T)

def test_linear_interpolant_edge_case_3d(self):
x = np.arange(0., 2.)
y_ = np.array([0., 1.])
y = np.empty((2, 2, 2))
y[:, 0, 0] = y_
y[:, 1, 0] = 2. * y_
y[:, 0, 1] = 3. * y_
y[:, 1, 1] = 4. * y_
ak = Akima1DInterpolator(x, y)
yi_ = np.array([0.5, 1.])
yi = np.empty((2, 2, 2))
yi[:, 0, 0] = yi_
yi[:, 1, 0] = 2. * yi_
yi[:, 0, 1] = 3. * yi_
yi[:, 1, 1] = 4. * yi_
xi = yi_
xp_assert_close(ak(xi), yi)

ak = Akima1DInterpolator(x, y.transpose(1, 0, 2), axis=1)
xp_assert_close(ak(xi), yi.transpose(1, 0, 2))

ak = Akima1DInterpolator(x, y.transpose(2, 1, 0), axis=2)
xp_assert_close(ak(xi), yi.transpose(2, 1, 0))


def test_degenerate_case_multidimensional(self):
# This test is for issue #5683.
x = np.array([0, 1, 2])
Expand Down

0 comments on commit ef7437a

Please sign in to comment.