Skip to content

Commit

Permalink
allow evaluation of an interpolant with a float
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 26, 2024
1 parent cba763a commit 5fd6e71
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from scipy import interpolate
from typing import Sequence

import numbers

import pybamm

Expand Down Expand Up @@ -102,8 +102,8 @@ def __init__(
"len(x1) should equal y=shape[0], "
f"but x1.shape={x1.shape} and y.shape={y.shape}"
)
# children should be a list not a symbol
if isinstance(children, pybamm.Symbol):
# children should be a list not a symbol or a number
if isinstance(children, (pybamm.Symbol, numbers.Number)):
children = [children]
# Either a single x is provided and there is one child
# or x is a 2-tuple and there are two children
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def test_interpolation(self):
interp.evaluate(y=np.array([2]))[:, 0], np.array([np.nan])
)

def test_interpolation_float(self):
x = np.linspace(0, 1, 200)
interp = pybamm.Interpolant(x, 2 * x, 0.5)
assert interp.evaluate() == 1.0

def test_interpolation_1_x_2d_y(self):
x = np.linspace(0, 1, 200)
y = np.tile(2 * x, (10, 1)).T
Expand Down

0 comments on commit 5fd6e71

Please sign in to comment.