Skip to content

Commit

Permalink
downloaded new task 5
Browse files Browse the repository at this point in the history
  • Loading branch information
ludvigls committed Oct 9, 2020
1 parent b735dbd commit 21e9a37
Show file tree
Hide file tree
Showing 8 changed files with 619 additions and 1 deletion.
1 change: 0 additions & 1 deletion 5
Submodule 5 deleted from 7e3a22
Binary file added ExPDA/data_for_pda.mat
Binary file not shown.
49 changes: 49 additions & 0 deletions ExPDA/estimatorduck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#%%
from typing import Dict, Any, Generic, TypeVar
from typing_extensions import Protocol, runtime

from mixturedata import MixtureParameters
from gaussparams import GaussParams

import numpy as np


T = TypeVar("T")


@runtime
class StateEstimator(Protocol[T]):
def predict(self, eststate: T, Ts: float) -> T:
...

def update(
self, z: np.ndarray, eststate: T, *, sensor_state: Dict[str, Any] = None
) -> T:
...

def step(self, z: np.ndarray, eststate: T, Ts: float) -> T:
...

def estimate(self, estastate: T) -> GaussParams:
...

def init_filter_state(self, init: Any) -> T:
...

def loglikelihood(
self, z: np.ndarray, eststate: T, *, sensor_state: Dict[str, Any] = None
) -> float:
...

def reduce_mixture(self, estimator_mixture: MixtureParameters[T]) -> T:
...

def gate(
self,
z: np.ndarray,
eststate: T,
gate_size_square: float,
*,
sensor_state: Dict[str, Any] = None
) -> bool:
...
69 changes: 69 additions & 0 deletions ExPDA/gaussparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Optional, Union, Tuple
from dataclasses import dataclass
from mytypes import ArrayLike
import numpy as np


@dataclass(init=False)
class GaussParams:
"""A class for holding Gaussian parameters"""

__slots__ = ["mean", "cov"]
mean: np.ndarray # shape=(n,)
cov: np.ndarray # shape=(n, n)

def __init__(self, mean: ArrayLike, cov: ArrayLike) -> None:
self.mean = np.asarray(mean, dtype=float)
self.cov = np.asarray(cov, dtype=float)

def __iter__(self): # in order to use tuple unpacking
return iter((self.mean, self.cov))


@dataclass(init=False)
class GaussParamList:
__slots__ = ["mean", "cov"]
mean: np.ndarray # shape=(N, n)
cov: np.ndarray # shape=(N, n, n)

def __init__(self, mean=None, cov=None):
if mean is not None and cov is not None:
self.mean = mean
self.cov = cov
else:
# container left empty
pass

@classmethod
def allocate(
cls,
shape: Union[int, Tuple[int, ...]], # list shape
n: int, # dimension
fill: Optional[float] = None, # fill the allocated arrays
) -> "GaussParamList":
if isinstance(shape, int):
shape = (shape,)

if fill is None:
return cls(np.empty((*shape, n)), np.empty((*shape, n, n)))
else:
return cls(np.full((*shape, n), fill), np.full((*shape, n, n), fill))

def __getitem__(self, key):
theCls = GaussParams if isinstance(key, int) else GaussParamList
return theCls(self.mean[key], self.cov[key])

def __setitem__(self, key, value):
if isinstance(value, (GaussParams, tuple)):
self.mean[key], self.cov[key] = value
elif isinstance(value, GaussParamList):
self.mean[key] = value.mean
self.cov[key] = value.cov
else:
raise NotImplementedError(f"Cannot set from type {value}")

def __len__(self):
return self.mean.shape[0]

def __iter__(self):
yield from (self[k] for k in range(len(self)))
64 changes: 64 additions & 0 deletions ExPDA/mixturedata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import (
Collection,
Generic,
TypeVar,
Union,
Sequence,
Any,
List,
)

# from singledispatchmethod import singledispatchmethod # pip install
from dataclasses import dataclass
import numpy as np

T = TypeVar("T")


@dataclass
class MixtureParameters(Generic[T]):
__slots__ = ["weights", "components"]
weights: np.ndarray
components: Sequence[T]


# class Array(Collection[T], Generic[T]):
# def __getitem__(self, key):
# ...

# def __setitem__(self, key, vaule):
# ...


# @dataclass
# class MixtureParametersList(Generic[T]):
# weights: np.ndarray
# components: Array[Sequence[T]]

# @classmethod
# def allocate(cls, shape: Union[int, Tuple[int, ...]], component_type: T):
# shape = (shape,) if isinstance(shape, int) else shape
# # TODO
# raise NotImplementedError

# @singledispatchmethod
# def __getitem__(self, key: Any) -> "MixtureParametersList[T]":
# return MixtureParametersList(self.weights[key], self.components[key])

# @__getitem__.register
# def _(self, key: int) -> MixtureParameters:
# return MixtureParameters(self.weights[key], self.components[key])

# def __setitem__(
# self,
# key: Union[int, slice],
# value: "Union[MixtureParameters[T], MixtureParametersList[T]]",
# ) -> None:
# self.weights[key] = value.weights
# self.components[key] = value.components

# def __len__(self):
# return self.weights.shape[0]

# def __iter__(self):
# yield from (self[k] for k in range(len(self)))
67 changes: 67 additions & 0 deletions ExPDA/mytypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import sys
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, Union, overload

# %% Taken from https://github.com/numpy/numpy/tree/master/numpy/typing
from numpy import dtype, ndarray

if sys.version_info >= (3, 8):
from typing import Protocol, TypedDict
HAVE_PROTOCOL = True
else:
try:
from typing_extensions import Protocol, TypedDict
except ImportError:
HAVE_PROTOCOL = False
else:
HAVE_PROTOCOL = True

_Shape = Tuple[int, ...]

# Anything that can be coerced to a shape tuple
_ShapeLike = Union[int, Sequence[int]]

_DtypeLikeNested = Any # TODO: wait for support for recursive types

if TYPE_CHECKING or HAVE_PROTOCOL:
# Mandatory keys
class _DtypeDictBase(TypedDict):
names: Sequence[str]
formats: Sequence[_DtypeLikeNested]

# Mandatory + optional keys
class _DtypeDict(_DtypeDictBase, total=False):
offsets: Sequence[int]
# Only `str` elements are usable as indexing aliases, but all objects are legal
titles: Sequence[Any]
itemsize: int
aligned: bool

# A protocol for anything with the dtype attribute
class _SupportsDtype(Protocol):
dtype: _DtypeLikeNested

else:
_DtypeDict = Any
_SupportsDtype = Any


DtypeLike = Union[
dtype, None, type, _SupportsDtype, str, Tuple[_DtypeLikeNested, int],
Tuple[_DtypeLikeNested, _ShapeLike], List[Any], _DtypeDict,
Tuple[_DtypeLikeNested, _DtypeLikeNested],
]


if TYPE_CHECKING or HAVE_PROTOCOL:
class _SupportsArray(Protocol):
@overload
def __array__(self, __dtype: DtypeLike = ...) -> ndarray: ...
@overload
def __array__(self, dtype: DtypeLike = ...) -> ndarray: ...
else:
_SupportsArray = Any


ArrayLike = Union[bool, int, float, complex, _SupportsArray, Sequence]

# %%
Loading

0 comments on commit 21e9a37

Please sign in to comment.