Skip to content

Commit

Permalink
Merge pull request #67 from lettercode/lettercode/upgrade-to-keras-v3
Browse files Browse the repository at this point in the history
Add Keras v3 implementation
  • Loading branch information
mlech26l authored Jun 19, 2024
2 parents 2b6fbdf + e39620e commit 4d3414d
Show file tree
Hide file tree
Showing 11 changed files with 1,399 additions and 14 deletions.
48 changes: 40 additions & 8 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ permissions:
contents: read

jobs:
build:

build_pytorch_backend:
runs-on: ubuntu-latest

container:
image: pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
env:
KERAS_BACKEND: torch
volumes:
- my_docker_volume:/volume_mount

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -35,6 +38,35 @@ jobs:
echo "PYTHONPATH=." >> $GITHUB_ENV
- name: Test with pytest
run: |
pytest ncps/tests/test_tf.py
pytest ncps/tests/test_torch.py
pytest ncps/tests/test_keras.py
build_tensorflow_backend:
runs-on: ubuntu-latest

container:
image: tensorflow/tensorflow:2.16.1
env:
KERAS_BACKEND: tensorflow
volumes:
- my_docker_volume:/volume_mount

steps:
- uses: actions/checkout@v3

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: set pythonpath
run: |
echo "PYTHONPATH=." >> $GITHUB_ENV
- name: Test with pytest
run: |
pytest ncps/tests/test_keras.py
44 changes: 44 additions & 0 deletions ncps/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2020-2021 Mathias Lechner
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import absolute_import

from .ltc_cell import LTCCell
from .mm_rnn import MixedMemoryRNN
from .cfc_cell import CfCCell
from .wired_cfc_cell import WiredCfCCell
from .cfc import CfC
from .ltc import LTC
from packaging.version import parse

try:
import keras
except:
raise ImportWarning(
"It seems like the Keras package is not installed\n"
"Please run"
"`$ pip install keras`. \n",
)

if parse(keras.__version__) < parse("3.0.0"):
raise ImportError(
"The Keras package version needs to be at least 3.0.0 \n"
"for ncps-keras to run. Currently, your Keras version is \n"
"{version}. Please upgrade with \n"
"`$ pip install --upgrade keras`. \n"
"You can use `pip freeze` to check afterwards that everything is "
"ok.".format(version=keras.__version__)
)
__all__ = ["CfC", "CfCCell", "LTC", "LTCCell", "MixedMemoryRNN", "WiredCfCCell"]
99 changes: 99 additions & 0 deletions ncps/keras/cfc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import keras

import ncps
from . import CfCCell, MixedMemoryRNN, WiredCfCCell


@keras.utils.register_keras_serializable(package="ncps", name="CfC")
class CfC(keras.layers.RNN):
def __init__(
self,
units: Union[int, ncps.wirings.Wiring],
mixed_memory: bool = False,
mode: str = "default",
activation: str = "lecun_tanh",
backbone_units: int = None,
backbone_layers: int = None,
backbone_dropout: float = None,
return_sequences: bool = False,
return_state: bool = False,
go_backwards: bool = False,
stateful: bool = False,
unroll: bool = False,
time_major: bool = False,
**kwargs,
):
"""Applies a `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ RNN to an input sequence.
Examples::
>>> from ncps.keras import CfC
>>>
>>> rnn = CfC(50)
>>> x = keras.random.uniform((2, 10, 20)) # (B,L,C)
>>> y = rnn(x)
:param units: Number of hidden units
:param mixed_memory: Whether to augment the RNN by a `memory-cell <https://arxiv.org/abs/2006.04418>`_ to help learn long-term dependencies in the data (default False)
:param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate). (default "default)
:param activation: Activation function used in the backbone layers (default "lecun_tanh")
:param backbone_units: Number of hidden units in the backbone layer (default 128)
:param backbone_layers: Number of backbone layers (default 1)
:param backbone_dropout: Dropout rate in the backbone layers (default 0)
:param return_sequences: Whether to return the full sequence or just the last output (default False)
:param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
:param go_backwards: If True, the input sequence will be process from back to the front (default False)
:param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
:param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
:param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False)
:param kwargs:
"""

if isinstance(units, ncps.wirings.Wiring):
if backbone_units is not None:
raise ValueError(f"Cannot use backbone_units in wired mode")
if backbone_layers is not None:
raise ValueError(f"Cannot use backbone_layers in wired mode")
if backbone_dropout is not None:
raise ValueError(f"Cannot use backbone_dropout in wired mode")
cell = WiredCfCCell(units, mode=mode, activation=activation)
else:
backbone_units = 128 if backbone_units is None else backbone_units
backbone_layers = 1 if backbone_layers is None else backbone_layers
backbone_dropout = 0.0 if backbone_dropout is None else backbone_dropout
cell = CfCCell(
units,
mode=mode,
activation=activation,
backbone_units=backbone_units,
backbone_layers=backbone_layers,
backbone_dropout=backbone_dropout,
)
if mixed_memory:
cell = MixedMemoryRNN(cell)
super(CfC, self).__init__(
cell,
return_sequences,
return_state,
go_backwards,
stateful,
unroll,
time_major,
**kwargs,
)
Loading

0 comments on commit 4d3414d

Please sign in to comment.