From 58653b8c39205736610242a093b7be9fd9856ae5 Mon Sep 17 00:00:00 2001
From: Sidd Karamcheti
Date: Sun, 10 Apr 2022 19:12:59 -0400
Subject: [PATCH 1/3] Bump README

---
 README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 75fcf89..26bcae5 100644
--- a/README.md
+++ b/README.md
@@ -37,9 +37,13 @@ Gets "best" 97.76% accuracy in 10 epochs @ 40s/epoch on a TitanRTX.
 ```
 # Following @frederick0329's/@albertgu's results: https://github.com/srush/annotated-s4/pull/43#issuecomment-1065444261
 python -m s4.train --dataset cifar-classification --model s4 --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 5e-3 --d_model 512
+
+# DSS Model
+python -m s4.train --dataset cifar-classification --model dss --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 5e-3 --d_model 512
 ```
-Gets "best" 85.81% accuracy after 100 epochs @ 3m8s/epoch on a TitanRTX
+S4 gets "best" 87.05% accuracy after 100 epochs @ 3m8s/epoch on a TitanRTX
+DSS gets "best" 88.90% accuracy after 100 epochs @ 3m11s/epoch on a TitanRTX
 ---

From 740fc01220ec337d93a895184a3e71885e86abcd Mon Sep 17 00:00:00 2001
From: Sidd Karamcheti
Date: Mon, 11 Apr 2022 14:58:59 -0400
Subject: [PATCH 2/3] Add outline for DSS post

---
 s4/dss.py | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++----
 s4/s4.py | 2 +-
 2 files changed, 117 insertions(+), 8 deletions(-)

diff --git a/s4/dss.py b/s4/dss.py
index fe43061..44ec7dd 100644
--- a/s4/dss.py
+++ b/s4/dss.py
@@ -1,14 +1,124 @@
-import s4.s4 as s4
-
+#

+# The Diagonal State Space Model
+#
+# Diagonal State Spaces are as Effective as Structured State Spaces
+#
+# Ankit Gupta
+#
+# ---
+#
+# *Note: This page is meant as a standalone complement to Section 2 [TODO Link] of the original
+# blog post.*
+#
+# The months following the release of the S4 paper by Gu et al. were characterized by a wave of excitement around the new
+# model, its ability to handle extremely long sequences, and generally, what such a departure from Transformer-based
+# architectures could mean. The original authors came out with a
+# [follow-up paper applying S4 to audio generation](https://arxiv.org/abs/2202.09729), and weeks later, a completely
+# [different group applied S4 to long-range movie clip classification](https://arxiv.org/abs/2204.01692).
+#
+# Yet, it remains hard to parse aspects of the implementation, especially the derivation of the diagonal plus low rank
+# constraint on $\boldsymbol{A}$. Not only was this math fairly complex, but the code also required custom CUDA
+# kernels, further obfuscating the implementation (which is why this blog uses Jax to efficiently compile the relevant
+# operations).
+#
+# However, at the end of March 2022, an alternative construction for state space models was proposed in [Diagonal
+# State Spaces are as Effective as Structured State Spaces](https://arxiv.org/abs/2203.14343). This short paper derives
+# an alternative construction of learnable state space models that 1) is simple, 2) requires no custom kernels, and
+# 3) can be efficiently implemented in Jax or PyTorch in just a dozen lines. The rest of this post steps through this
+# alternative derivation, **a complete standalone for Section 2** of the original Annotated S4 post.
+#
+# We'll still be using Jax with the Flax NN Library for consistency with the original post, though this Diagonal State
+# Space (DSS) variant can be easily implemented in PyTorch with some minor changes.
+
+# import s4.s4 as s4 TODO -- For some reason breaks streamlit...
+import s4
 from functools import partial
 import jax
 import jax.numpy as np
 from flax import linen as nn
 from jax.nn.initializers import lecun_normal
-from jax.numpy.linalg import eig
 rng = jax.random.PRNGKey(1)
+# ## Table of Contents
+#