Skip to content

Official PyTorch Implementation of the Longhorn Deep State Space Model

Notifications You must be signed in to change notification settings

Cranial-XIX/longhorn

Repository files navigation

This repo contains the official PyTorch implementation of the Longhorn sequence modeling architecture.

Figure 1

Main Insight: The recurrent form of SSMs can be viewed as solving an online learning problem.

We believe the self-attention layer in the Transformer is performing associative recall (AR). For instance, the model observes a stream of (k, v) pairs. At test time, it is provided with a key (k) and is asked to retrieve its corresponding value (v). The AR problem is like an online prediction problem, where the key is the input and the value is the label. Based on this insight, we make a parallelizable RNN, named Longhorn, whose per-token update explicitly solves this online prediction problem in closed form.

Figure 2

Main Observation: Longhorn (1.3B), when trained on 100B tokens on the SlimPajama dataset, achieves 1.8x better sample efficiency than Mamba. This means that it achieves the same average validation perplexity as Mamba across 8 downstream benchmarks using about half the data. In addition, Longhorn successfully extrapolates up to 16x of its training context length.


For the convenience of the research community, we also provide the architectures for:

The codebase is adapted from the nanoGPT project by Andrej Karpathy.

1. Installtion

Please run the following commands to install the relevant dependencies:

conda create -n longhorn python=3.10.9
conda activate longhorn
git clone https://github.com/Cranial-XIX/longhorn.git
cd longhorn
pip install -r requirements.txt
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121  # need at least CUDA12
pip install triton==2.2.0  # for GLA

Then install the CUDA kernel for both Longhorn and Mamba

git clone https://github.com/Cranial-XIX/longhorn_cuda.git
cd longhorn_cuda
chmod +x rebuild.sh && bash rebuild.sh

2. Test Run

You can then test if the dependency has been successfully by training on Shakespeare dataset.

First, prepare the dataset

cd data/shakespeare_char
python prepare.py
cd ../..

Then run with:

bash test_shakespeare_run.sh MODEL

where MODEL can be any of llama, retnet, rwkv, gla, mamba, longhorn.

3. Train on OpenWebText

Then you can run on the OpenWebText dataset, assuming you have 4x 80G A100 GPUs.

Again, prepare the dataset

cd data/openwebtext
python prepare.py
cd ../..

Then run with:

bash run.sh MODEL

where MODEL can be any of llama, retnet, rwkv, gla, mamba, longhorn.

4. Citation

If you find longhorn or this repo to be useful, please consider citing our paper:

@misc{liu2024longhornstatespacemodels,
      title={Longhorn: State Space Models are Amortized Online Learners},
      author={Bo Liu and Rui Wang and Lemeng Wu and Yihao Feng and Peter Stone and Qiang Liu},
      year={2024},
      eprint={2407.14207},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.14207},
}

About

Official PyTorch Implementation of the Longhorn Deep State Space Model

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published