Skip to content

Commit

Permalink
Merge pull request pybamm-team#3578 from pipliggins/expression-tree-t…
Browse files Browse the repository at this point in the history
…yping

Adds typing to expression tree
  • Loading branch information
martinjrobins authored Feb 21, 2024
2 parents cf686e7 + eeceaa7 commit 6d85281
Show file tree
Hide file tree
Showing 37 changed files with 1,053 additions and 565 deletions.
1 change: 1 addition & 0 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _reset(self):
self.register("Sulzer2021")
self.register("Harris2020")

@staticmethod
def _caller_name():
"""
Returns the qualified name of classes that call :meth:`register` internally.
Expand Down
2 changes: 1 addition & 1 deletion pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str],
operating_conditions: list[str | tuple[str]],
period: str = "1 minute",
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
52 changes: 31 additions & 21 deletions pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#
# NumpyArray class
#
from __future__ import annotations
import numpy as np
from scipy.sparse import csr_matrix, issparse
from typing import TYPE_CHECKING

import pybamm
from pybamm.util import have_optional_dependency
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType

if TYPE_CHECKING: # pragma: no cover
import sympy


class Array(pybamm.Symbol):
Expand Down Expand Up @@ -36,13 +42,13 @@ class Array(pybamm.Symbol):

def __init__(
self,
entries,
name=None,
domain=None,
auxiliary_domains=None,
domains=None,
entries_string=None,
):
entries: np.ndarray | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
domains: DomainsType = None,
entries_string: str | None = None,
) -> None:
# if
if isinstance(entries, list):
entries = np.array(entries)
Expand All @@ -59,8 +65,6 @@ def __init__(

@classmethod
def _from_json(cls, snippet: dict):
instance = cls.__new__(cls)

if isinstance(snippet["entries"], dict):
matrix = csr_matrix(
(
Expand All @@ -73,14 +77,12 @@ def _from_json(cls, snippet: dict):
else:
matrix = snippet["entries"]

instance.__init__(
return cls(
matrix,
name=snippet["name"],
domains=snippet["domains"],
)

return instance

@property
def entries(self):
return self._entries
Expand All @@ -100,7 +102,7 @@ def entries_string(self):
return self._entries_string

@entries_string.setter
def entries_string(self, value):
def entries_string(self, value: None | tuple):
# We must include the entries in the hash, since different arrays can be
# indistinguishable by class, name and domain alone
# Slightly different syntax for sparse and non-sparse matrices
Expand All @@ -110,10 +112,10 @@ def entries_string(self, value):
entries = self._entries
if issparse(entries):
dct = entries.__dict__
self._entries_string = ["shape", str(dct["_shape"])]
entries_string = ["shape", str(dct["_shape"])]
for key in ["data", "indices", "indptr"]:
self._entries_string += [key, dct[key].tobytes()]
self._entries_string = tuple(self._entries_string)
entries_string += [key, dct[key].tobytes()]
self._entries_string = tuple(entries_string)
# self._entries_string = str(entries.__dict__)
else:
self._entries_string = (entries.tobytes(),)
Expand All @@ -124,7 +126,7 @@ def set_id(self):
(self.__class__, self.name, *self.entries_string, *tuple(self.domain))
)

def _jac(self, variable):
def _jac(self, variable) -> pybamm.Matrix:
"""See :meth:`pybamm.Symbol._jac()`."""
# Return zeros of correct size
jac = csr_matrix((self.size, variable.evaluation_array.count(True)))
Expand All @@ -139,15 +141,21 @@ def create_copy(self):
entries_string=self.entries_string,
)

def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None):
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
return self._entries

def is_constant(self):
"""See :meth:`pybamm.Symbol.is_constant()`."""
return True

def to_equation(self):
def to_equation(self) -> sympy.Array:
"""Returns the value returned by the node when evaluated."""
sympy = have_optional_dependency("sympy")
entries_list = self.entries.tolist()
Expand Down Expand Up @@ -178,7 +186,7 @@ def to_json(self):
return json_dict


def linspace(start, stop, num=50, **kwargs):
def linspace(start: float, stop: float, num: int = 50, **kwargs) -> pybamm.Array:
"""
Creates a linearly spaced array by calling `numpy.linspace` with keyword
arguments 'kwargs'. For a list of 'kwargs' see the
Expand All @@ -187,7 +195,9 @@ def linspace(start, stop, num=50, **kwargs):
return pybamm.Array(np.linspace(start, stop, num, **kwargs))


def meshgrid(x, y, **kwargs):
def meshgrid(
x: pybamm.Array, y: pybamm.Array, **kwargs
) -> tuple[pybamm.Array, pybamm.Array]:
"""
Return coordinate matrices as from coordinate vectors by calling
`numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs'
Expand Down
61 changes: 39 additions & 22 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#
# Classes and methods for averaging
#
from __future__ import annotations
from typing import Callable
import pybamm


Expand All @@ -14,13 +16,19 @@ class _BaseAverage(pybamm.Integral):
The child node
"""

def __init__(self, child, name, integration_variable):
def __init__(
self,
child: pybamm.Symbol,
name: str,
integration_variable: list[pybamm.IndependentVariable]
| pybamm.IndependentVariable,
) -> None:
super().__init__(child, integration_variable)
self.name = name


class XAverage(_BaseAverage):
def __init__(self, child):
def __init__(self, child: pybamm.Symbol) -> None:
if all(n in child.domain[0] for n in ["negative", "particle"]):
x = pybamm.standard_spatial_vars.x_n
elif all(n in child.domain[0] for n in ["positive", "particle"]):
Expand All @@ -30,56 +38,60 @@ def __init__(self, child):
integration_variable = x
super().__init__(child, "x-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return x_average(child)


class YZAverage(_BaseAverage):
def __init__(self, child):
def __init__(self, child: pybamm.Symbol) -> None:
y = pybamm.standard_spatial_vars.y
z = pybamm.standard_spatial_vars.z
integration_variable = [y, z]
integration_variable: list[pybamm.IndependentVariable] = [y, z]
super().__init__(child, "yz-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return yz_average(child)


class ZAverage(_BaseAverage):
def __init__(self, child):
integration_variable = [pybamm.standard_spatial_vars.z]
def __init__(self, child: pybamm.Symbol) -> None:
integration_variable: list[pybamm.IndependentVariable] = [
pybamm.standard_spatial_vars.z
]
super().__init__(child, "z-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return z_average(child)


class RAverage(_BaseAverage):
def __init__(self, child):
integration_variable = [pybamm.SpatialVariable("r", child.domain)]
def __init__(self, child: pybamm.Symbol) -> None:
integration_variable: list[pybamm.IndependentVariable] = [
pybamm.SpatialVariable("r", child.domain)
]
super().__init__(child, "r-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return r_average(child)


class SizeAverage(_BaseAverage):
def __init__(self, child, f_a_dist):
def __init__(self, child: pybamm.Symbol, f_a_dist) -> None:
R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian")
integration_variable = [R]
integration_variable: list[pybamm.IndependentVariable] = [R]
super().__init__(child, "size-average", integration_variable)
self.f_a_dist = f_a_dist

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return size_average(child, f_a_dist=self.f_a_dist)


def x_average(symbol):
def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the x-direction.
Expand Down Expand Up @@ -168,7 +180,7 @@ def x_average(symbol):
return XAverage(symbol)


def z_average(symbol):
def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the z-direction.
Expand Down Expand Up @@ -205,7 +217,7 @@ def z_average(symbol):
return ZAverage(symbol)


def yz_average(symbol):
def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the y-z-direction.
Expand Down Expand Up @@ -239,11 +251,11 @@ def yz_average(symbol):
return YZAverage(symbol)


def xyz_average(symbol):
def xyz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
return yz_average(x_average(symbol))


def r_average(symbol):
def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the r-direction.
Expand Down Expand Up @@ -286,7 +298,9 @@ def r_average(symbol):
return RAverage(symbol)


def size_average(symbol, f_a_dist=None):
def size_average(
symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None
) -> pybamm.Symbol:
"""Convenience function for averaging over particle size R using the area-weighted
particle-size distribution.
Expand Down Expand Up @@ -339,7 +353,10 @@ def size_average(symbol, f_a_dist=None):
return SizeAverage(symbol, f_a_dist)


def _sum_of_averages(symbol, average_function):
def _sum_of_averages(
symbol: pybamm.Addition | pybamm.Subtraction,
average_function: Callable[[pybamm.Symbol], pybamm.Symbol],
):
if isinstance(symbol, pybamm.Addition):
return average_function(symbol.left) + average_function(symbol.right)
elif isinstance(symbol, pybamm.Subtraction):
Expand Down
Loading

0 comments on commit 6d85281

Please sign in to comment.