Skip to content

Commit

Permalink
Fix init value to speed up compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
liukidar committed May 28, 2024
1 parent c9c7bb7 commit 4e3ac11
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 40 deletions.
112 changes: 73 additions & 39 deletions pcax/core/_parameter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
__all__ = [
"BaseParam",
"Param",
"ParamDict",
"ParamCache",
"get",
"set"
]
__all__ = ["BaseParam", "Param", "ParamDict", "ParamCache", "get", "set"]


import abc
import re
from typing import Tuple, Dict, Any, Type
import functools

Expand Down Expand Up @@ -42,29 +34,29 @@ def __new__(mcs, name, bases, dct):
_cls,
flatten_func=_BaseParamMeta.flatten_parameter,
flatten_with_keys=_BaseParamMeta.flatten_parameter_with_keys,
unflatten_func=functools.partial(_BaseParamMeta.unflatten_parameter, cls=_cls)
unflatten_func=functools.partial(_BaseParamMeta.unflatten_parameter, cls=_cls),
)

return _cls

@staticmethod
def flatten_parameter(param: 'BaseParam') -> Tuple[Any, Dict[str, Any]]:
def flatten_parameter(param: "BaseParam") -> Tuple[Any, Dict[str, Any]]:
_aux_data = dict.copy(param.__dict__)
del _aux_data["_value"]

return (param._value,), _aux_data

@staticmethod
def flatten_parameter_with_keys(param: 'BaseParam') -> Tuple[Any, Dict[str, Any]]:
def flatten_parameter_with_keys(param: "BaseParam") -> Tuple[Any, Dict[str, Any]]:
_aux_data = dict.copy(param.__dict__)
del _aux_data["_value"]

return ((jax.tree_util.GetAttrKey("value"), param._value),), _aux_data

@staticmethod
def unflatten_parameter(aux_data: Dict[str, Any], children: Any, *, cls: Type['BaseParam']) -> 'BaseParam':
def unflatten_parameter(aux_data: Dict[str, Any], children: Any, *, cls: Type["BaseParam"]) -> "BaseParam":
_param = object.__new__(cls)

_param.__dict__ = dict.copy(aux_data)
_param._value = children[0]

Expand All @@ -76,10 +68,7 @@ class BaseParam(metaclass=_BaseParamMeta):
Base abstract class for all parameters. It is used to detect whether an object is a parameter or not.
"""

def __init__(
self,
value: jax.Array | Any | None = None
):
def __init__(self, value: jax.Array | Any | None = None):
"""
_BaseParam constructor.
Expand All @@ -90,11 +79,11 @@ def __init__(

@abc.abstractmethod
def get(self):
raise NotImplemented
raise NotImplementedError()

@abc.abstractmethod
def set(self, value):
raise NotImplemented
raise NotImplementedError()

def __bool__(self):
raise TypeError(
Expand All @@ -103,7 +92,7 @@ def __bool__(self):
)


# Parameter #############################################################################################################
# Parameter ###########################################################################################################


class DynamicParam(BaseParam):
Expand All @@ -119,121 +108,169 @@ class Param(DynamicParam):
def get(self) -> jax.Array:
return self._value

def set(self, value: jax.Array) -> 'Param':
def set(self, value: jax.Array) -> "Param":
self._value = value

return self

def __repr__(self):
rvalue = f"[{','.join(map(str, self.shape))}], {self.dtype}" if isinstance(self._value, jax.Array) else repr(self._value)
rvalue = (
f"[{','.join(map(str, self.shape))}], {self.dtype}"
if isinstance(self._value, jax.Array)
else repr(self._value)
)
t = f"{self.__class__.__name__}({rvalue})"

return t

# Python looks up special methods only on classes, not instances. This means
# these methods needs to be defined explicitly rather than relying on
# __getattr__.
def __neg__(self):
return self._value.__neg__() # noqa: E704

def __pos__(self):
return self._value.__pos__() # noqa: E704

def __abs__(self):
return self._value.__abs__() # noqa: E704

def __invert__(self):
return self._value.__invert__() # noqa: E704

def __eq__(self, __other):
return self._value.__eq__(get(__other)) # noqa: E704

def __ne__(self, __other):
return self._value.__ne__(get(__other)) # noqa: E704

def __lt__(self, __other):
return self._value.__lt__(get(__other)) # noqa: E704

def __le__(self, __other):
return self._value.__le__(get(__other)) # noqa: E704

def __gt__(self, __other):
return self._value.__gt__(get(__other)) # noqa: E704

def __ge__(self, __other):
return self._value.__ge__(get(__other)) # noqa: E704

def __add__(self, __other):
return self._value.__add__(get(__other)) # noqa: E704

def __radd__(self, __other):
return self._value.__radd__(get(__other)) # noqa: E704

def __iadd__(self, __other):
self._value = self._value.__add__(get(__other)) # noqa: E704
return self

def __sub__(self, __other):
return self._value.__sub__(get(__other)) # noqa: E704

def __rsub__(self, __other):
return self._value.__rsub__(get(__other)) # noqa: E704

def __isub__(self, __other):
self._value = self._value.__sub__(get(__other)) # noqa: E704
return self

def __mul__(self, __other):
return self._value.__mul__(get(__other)) # noqa: E704

def __rmul__(self, __other):
return self._value.__rmul__(get(__other)) # noqa: E704

def __imul__(self, __other):
self._value = self._value.__mul__(get(__other)) # noqa: E704
return self

def __div__(self, __other):
return self._value.__div__(get(__other)) # noqa: E704

def __rdiv__(self, __other):
return self._value.__rdiv__(get(__other)) # noqa: E704

def __idiv__(self, __other):
self._value = self._value.__div__(get(__other)) # noqa: E704
return self

def __truediv__(self, __other):
return self._value.__truediv__(get(__other)) # noqa: E704

def __rtruediv__(self, __other):
return self._value.__rtruediv__(get(__other)) # noqa: E704

def __floordiv__(self, __other):
return self._value.__floordiv__(get(__other)) # noqa: E704

def __rfloordiv__(self, __other):
return self._value.__rfloordiv__(get(__other)) # noqa: E704

def __divmod__(self, __other):
return self._value.__divmod__(get(__other)) # noqa: E704

def __rdivmod__(self, __other):
return self._value.__rdivmod__(get(__other)) # noqa: E704

def __mod__(self, __other):
return self._value.__mod__(get(__other)) # noqa: E704

def __rmod__(self, __other):
return self._value.__rmod__(get(__other)) # noqa: E704

def __pow__(self, __other):
return self._value.__pow__(get(__other)) # noqa: E704

def __rpow__(self, __other):
return self._value.__rpow__(get(__other)) # noqa: E704

def __matmul__(self, __other):
return self._value.__matmul__(get(__other)) # noqa: E704

def __rmatmul__(self, __other):
return self._value.__rmatmul__(get(__other)) # noqa: E704

def __and__(self, __other):
return self._value.__and__(get(__other)) # noqa: E704

def __rand__(self, __other):
return self._value.__rand__(get(__other)) # noqa: E704

def __or__(self, __other):
return self._value.__or__(get(__other)) # noqa: E704

def __ror__(self, __other):
return self._value.__ror__(get(__other)) # noqa: E704

def __xor__(self, __other):
return self._value.__xor__(get(__other)) # noqa: E704

def __rxor__(self, __other):
return self._value.__rxor__(get(__other)) # noqa: E704

def __lshift__(self, __other):
return self._value.__lshift__(get(__other)) # noqa: E704

def __rlshift__(self, __other):
return self._value.__rlshift__(get(__other)) # noqa: E704

def __rshift__(self, __other):
return self._value.__rshift__(get(__other)) # noqa: E704

def __rrshift__(self, __other):
return self._value.__rrshift__(get(__other)) # noqa: E704

def __round__(self, ndigits=None):
return self._value.__round__(ndigits) # noqa: E704

def __getitem__(self, __idx):
return self._value.__getitem__(__idx)

def __array__(self, dtype=None):
return self._value.__array__(dtype)

def __getattr__(self, __name):
return getattr(self._value, __name)

Expand All @@ -254,7 +291,7 @@ def ndim(self):


class ParamDict(DynamicParam):
def __init__(self, value: Dict[str, jax.Array | Any | None] = {}):
def __init__(self, value: Dict[str, jax.Array | Any | None] = None):
super().__init__(value)

def __getitem__(self, __key: str) -> Any:
Expand All @@ -266,17 +303,13 @@ def __setitem__(self, __key: str, __value: jax.Array) -> None:
if self._value is None:
self._value = {}
self._value[__key] = __value

def __contains__(self, __key: str) -> bool:
return __key in self._value

def get(self, key: str | None = None, default: jax.Array | Any | None = None) -> Any:
return (
self._value.get(key, default)
if key is not None
else self._value
)

return self._value.get(key, default) if key is not None else self._value

def set(self, value) -> None:
self._value = value

Expand All @@ -288,15 +321,16 @@ class ParamCache:
"""
A simple sentinel class used to identify all parameters used as a temporary cache.
"""

pass


# Utils ################################################################################################################


def get(x: Any | BaseParam) -> Any:
"""Return the value encapsulated in the input argument if it is a BaseParam, otherwise return the input argument itself.
Used in ambiguous situations to ensure that the input is a value and not a BaseParam.
"""Return the value encapsulated in the input argument if it is a BaseParam, otherwise return the input argument
itself. Used in ambiguous situations to ensure that the input is a value and not a BaseParam.
Args:
x (Any | BaseParam): input argument
Expand Down
2 changes: 1 addition & 1 deletion pcax/predictive_coding/_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class VodeParam(Param):
class Cache(ParamDict, ParamCache):
def __init__(self, params: Dict[str, jax.Array] = {}):
def __init__(self, params: Dict[str, jax.Array] = None):
super().__init__(params)

def __init__(
Expand Down

0 comments on commit 4e3ac11

Please sign in to comment.