Skip to content

Commit

Permalink
Fix some ignored type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
pipliggins committed Jan 25, 2024
1 parent 21e4107 commit d1ae819
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str],
operating_conditions: list[str | tuple[str]],
period: str = "1 minute",
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
for cycle in operating_conditions:
# Check types and convert to list
if not isinstance(cycle, tuple):
cycle = (cycle,) # type: ignore[assignment]
cycle = (cycle,)
operating_conditions_cycles.append(cycle)

self.operating_conditions_cycles = operating_conditions_cycles
Expand Down
12 changes: 6 additions & 6 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,18 +590,18 @@ def full_like(symbols: tuple[pybamm.Symbol, ...], fill_value: float) -> pybamm.S
return pybamm.Scalar(fill_value)
try:
shape = sum_symbol.shape
# use vector or matrix
if shape[1] == 1:
array_type: type[pybamm.Vector] = pybamm.Vector
else:
array_type: type[pybamm.Matrix] = pybamm.Matrix # type:ignore[no-redef]

# return dense array, except for a matrix of zeros
if shape[1] != 1 and fill_value == 0:
entries = csr_matrix(shape)
else:
entries = fill_value * np.ones(shape)

return array_type(entries, domains=sum_symbol.domains)
# use vector or matrix
if shape[1] == 1:
return pybamm.Vector(entries, domains=sum_symbol.domains)
else:
return pybamm.Matrix(entries, domains=sum_symbol.domains)

except NotImplementedError:
if (
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def diff(self, variable: pybamm.Symbol):
return pybamm.Scalar(1)
else:
children = self.orphans
partial_derivatives = [None] * len(children)
partial_derivatives: list[None | pybamm.Symbol] = [None] * len(children)
for i, child in enumerate(self.children):
# if variable appears in the function, differentiate
# function, and apply chain rule
Expand All @@ -87,9 +87,9 @@ def diff(self, variable: pybamm.Symbol):
# remove None entries
partial_derivatives = [x for x in partial_derivatives if x is not None]

derivative = sum(partial_derivatives) # type: ignore[arg-type]
derivative = sum(partial_derivatives)
if derivative == 0:
derivative = pybamm.Scalar(0) # type: ignore[assignment]
return pybamm.Scalar(0)

return derivative

Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ def __init__(
self.dimension = 1
if interpolator == "linear":
if extrapolate is False:
fill_value = np.nan
fill_value_1: float | str = np.nan
elif extrapolate is True:
fill_value = "extrapolate" # type: ignore[assignment]
fill_value_1 = "extrapolate"
interpolating_function = interpolate.interp1d(
x1,
y.T,
bounds_error=False,
fill_value=fill_value,
fill_value=fill_value_1,
)
elif interpolator == "cubic":
interpolating_function = interpolate.CubicSpline(
Expand Down
4 changes: 1 addition & 3 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def dot_product(self, b):
"""
# assume b is a column vector
result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
return result.at[self.row].add(
self.data.reshape(-1, 1) * b[self.col] # type:ignore[index]
)
return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])

def scalar_multiply(self, b: float):
"""
Expand Down
9 changes: 6 additions & 3 deletions pybamm/expression_tree/operations/serialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,15 @@ def _get_pybamm_class(self, snippet: dict):
try:
empty_class = self._Empty()
empty_class.__class__ = class_

return empty_class

except TypeError:
# Mesh objects have a different layouts
empty_class = self._EmptyDict() # type: ignore[assignment]
empty_class.__class__ = class_
empty_dict_class = self._EmptyDict()
empty_dict_class.__class__ = class_

return empty_class
return empty_dict_class

def _deconstruct_pybamm_dicts(self, dct: dict):
"""
Expand Down

0 comments on commit d1ae819

Please sign in to comment.