RWKV-2 is an RNN with Transformer-level performance, which can also be directly trained like a GPT transformer (parallelizable). And it's attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.
So it combines the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding (using the final hidden state).
Reddit discussion: https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/
Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227
You can find me (BlinkDL) in the EleutherAI Discord: https://www.eleuther.ai/get-involved/
I am looking for CUDA gurus to optimize the kernel :) Please contact me if you are interested. Thank you.
User feedback:
I've so far toyed around with the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good - similar ppl to models taking much, much longer to train.
dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.
I am training RWKV-2 on the Pile (https://github.com/BlinkDL/RWKV-v2-RNN-Pile):
All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, and I believe you can run a 1.5B params RWKV-v2-RNN with reasonable speed on your phone.
Check https://github.com/BlinkDL/RWKV-v2-RNN-Pile for L24-D1024 and L12-D768 models trained on the Pile (and the latest code). It's very fast on CPU (the default mode).
Read the inference code in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/blob/main/src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks (you should probably begin with .xx and .aa/.bb (.aa divided by .bb)).
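As an illustration, here is a hypothetical sketch of building such an embedding. The attribute names and state layout below are assumptions for illustration; the real layout is in the repo's model.py.

import torch

# Hypothetical sketch: turn the final RNN hidden state into a sentence embedding.
# Assumes `states` is a list with one entry per layer, each holding tensors
# xx, aa, bb as in the repo's RNN mode (actual names/layout may differ).
def sentence_embedding(states):
    parts = []
    for s in states:
        parts.append(s.xx)                    # token-shift state
        parts.append(s.aa / (s.bb + 1e-9))    # normalized EMA: aa / bb
    return torch.cat(parts, dim=-1)           # one flat vector for downstream tasks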
See the release here for a 27M params model on enwik8 with 0.72 BPC (dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN.
You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
Training: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN
You will be training the "GPT" version because it's parallelizable and faster to train. I find RWKV-2 can extrapolate, so training with ctxLen 768 can work for a ctxLen of several thousand. You can fine-tune the model with a longer ctxLen later and it will quickly adapt.
My LR schedule for the L24-D1024 RWKV-2:
Fixing NaN or loss spikes: load a previous checkpoint and decrease the LR a bit. I find you can decay the LR faster than for GPT, eventually down to 1/50 of LR_max.
UPDATE: Search for "RWKV v2+" here and change RWKV-2 to PreLN to make it more stable.
Fine-tuning: see https://github.com/BlinkDL/RWKV-v2-RNN-Pile.
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
It also uses a number of my tricks, such as:
- SmallInitEmb: https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers), which helps the embedding quality and stabilizes Post-LN (which is what I am using).
- Token-shift: https://github.com/BlinkDL/RWKV-LM#token-shift-time-shift-mixing (applicable to all transformers), especially helpful for char-level models.
- Head-QK: https://github.com/BlinkDL/RWKV-LM#the-head-qk-trick-learning-to-copy-and-avoid-tokens (applicable to all transformers). Note: it's helpful, but I disabled it in the Pile model to keep it 100% RNN.
- Extra R-gate in the FFN (applicable to all transformers). I am also using reluSquared from Primer.
- Better initialization: I init most of the matrices to ZERO (see RWKV_Init in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py).
- Transferring some parameters from a small model to a large model, for faster and better convergence (see https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/).
- My CUDA kernel: https://github.com/BlinkDL/RWKV-CUDA, to speed up training.
The a b c d factors work together to build a time-decay curve: X, 1, W, W^2, W^3, ...
Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
- a and b: EMAs of kv and k.
- c and d: these are a and b combined with "self-attention".
kv / k is the memory mechanism. A token with a high k can be remembered for a long duration, if W is close to 1 in that channel.
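A minimal sketch of this recurrence for one channel (my own illustration, not the repo's exact code; I'm assuming W is the per-channel decay and X is the extra weight of the current token, and omitting the k-clamping and K_EPS details):

from math import exp

def sigmoid(z):
    return 1.0 / (1.0 + exp(-z))

# Hedged sketch, one channel, one step.
# W: per-channel time-decay (0 < W < 1); X: extra weight of the current token.
def rwkv2_step(a, b, r, k, v, W, X):
    c = a + X * exp(k) * v      # numerator: decayed history + current token weighted by X
    d = b + X * exp(k)          # denominator
    out = sigmoid(r) * c / d    # gate by receptance r
    a = W * a + exp(k) * v      # EMA of kv -> past tokens get weights 1, W, W^2, ...
    b = W * b + exp(k)          # EMA of k
    return out, a, b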
RWKV-2 is parallelizable because the time-decay of each channel is data-independent (and trainable). For example, in a usual RNN you can adjust the time-decay of a channel from, say, 0.8 to 0.5 (these are called "gates"), while in RWKV-2 you simply move the information from a W=0.8 channel to a W=0.5 channel to achieve the same effect.
Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:
xx = self.time_shift(x) # xx = x shifted by one step (the previous token's features)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) # mix of current & previous token for K
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) # mix for V
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) # mix for R
Use preLN instead of postLN (more stable & faster convergence):
if self.layer_id == 0:
    x = self.ln0(x) # extra LN for the first layer (normalizes the embedding)
x = x + self.att(self.ln1(x)) # preLN before attention
x = x + self.ffn(self.ln2(x)) # preLN before FFN
I need a better CUDA kernel to (1) pull off maxK so there's no need to clamp k to 60, (2) fix divide-by-zero without using K_EPS, and (3) support bf16/fp16. Please let me know if you are a CUDA expert :)
Removing the maxK limitation will also make it easy to clear the state of a kv & k channel, by using a huge K.
Let F[t] be the system state at t.
Let x[t] be the new external input at t.
In GPT, predicting F[t+1] requires considering F[0], F[1], ..., F[t]. So it takes O(T^2) to generate a sequence of length T.
The simplified formula for GPT:

$$F[t+1]=\frac{\sum_{i=0}^{t} \exp(Q x[t] \cdot K F[i]) \cdot (V F[i])}{\sum_{i=0}^{t} \exp(Q x[t] \cdot K F[i])}$$
It's very capable in theory; however, that does not mean we can fully utilize its capability with the usual optimizers. I suspect the loss landscape is too difficult for our current methods.
Compare with the simplified formula for RWKV-2 (the parallel mode, which looks similar to Apple's AFT):

$$F[t+1]=\sigma(R x[t]) \cdot \frac{\sum_{i=0}^{t} \exp(W \cdot (t-i)) \cdot \exp(K F[i]) \cdot (V F[i])}{\sum_{i=0}^{t} \exp(W \cdot (t-i)) \cdot \exp(K F[i])}$$
The R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).
In GPT, the contribution of F[i] to F[t+1] is weighted by $\exp(Q x[t] \cdot K F[i])$.
In RWKV-2, the contribution of F[i] to F[t+1] is weighted by $\sigma(R x[t]) \cdot \exp(W \cdot (t-i)) \cdot \exp(K F[i])$.
- The $\sigma$ is a non-linearity and we can use sigmoid.
- Note $\sigma(R x[t])$ is not in the denominator, and I call R the "receptance".
- The $\exp(W \cdot (t-i))$ is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).
Here comes the punchline: we can rewrite it into an RNN (recursive formula). Note:

$$F[1]=\sigma(R x[0]) \cdot \frac{\exp(K F[0]) \cdot (V F[0])}{\exp(K F[0])}$$

$$F[2]=\sigma(R x[1]) \cdot \frac{\exp(K F[1]) \cdot (V F[1]) + \exp(W) \cdot \exp(K F[0]) \cdot (V F[0])}{\exp(K F[1]) + \exp(W) \cdot \exp(K F[0])}$$

Therefore it's straightforward to verify:

$$F[t+1]=\sigma(R x[t]) \cdot \frac{\exp(K F[t]) \cdot (V F[t]) + \exp(W) \cdot A[t]}{\exp(K F[t]) + \exp(W) \cdot B[t]}$$

where A[t] and B[t] are the numerator and denominator of the previous step, respectively.
I believe RWKV-2 is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.
Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.
I am using a trick to sample the Pile deterministically yet randomly enough.
Let's say the Pile has x chunks (a chunk = ctx_len tokens).
Pick a prime number p just less than x, and make sure p ≡ 2 (mod 3).
Use (step * step * step) mod p to sample it. Add some bias to step for extra randomness.
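A minimal sketch of this sampler (the bias handling is illustrative). Since p ≡ 2 (mod 3), cubing is a bijection mod p, so every chunk index below p is visited exactly once per p steps:

def chunk_index(step, p, bias=0):
    # p is a prime with p % 3 == 2, chosen just below the number of chunks
    s = (step + bias) % p
    return pow(s, 3, p)          # (s*s*s) mod p, a bijection on [0, p)

# e.g. the first few sampled chunk indices for a toy p:
# [chunk_index(t, p=17) for t in range(5)] -> [0, 1, 8, 10, 13]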
We propose a new sampling method called top-p-x:
It's like top-p; the only difference is that you also keep all tokens whose prob > x.
Try x = 0.01 first.
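A hedged sketch of top-p-x (not the exact code in src/utils.py): standard nucleus filtering, except tokens with prob > x are always kept.

import torch
import torch.nn.functional as F

def sample_top_p_x(logits, top_p=0.9, x=0.01):
    # logits: 1-D tensor over the vocabulary
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    outside_nucleus = (cumulative - sorted_probs) > top_p   # would be cut by plain top-p
    keep_anyway = sorted_probs > x                          # ...but keep all tokens with prob > x
    sorted_probs[outside_nucleus & ~keep_anyway] = 0
    sorted_probs = sorted_probs / sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx[choice]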
I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is that we can model the loss-curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using a variational method. Moreover, we can predict the final loss with reasonable accuracy.
UPDATE: In "Conclusion 1.", use the best-fitting regime (ignore the initial steps where our approximations break down) to fit the parameters.
Try this: fixed lr for 1 hr, then exponential decay to 0.2 * lr in 12 hrs, and choose the t=[1hr, 13hr] segment.
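A minimal sketch of that probe schedule (a hypothetical helper, not the repo's trainer code):

def probe_lr(t_hours, base_lr):
    # constant for the first hour, then exponential decay to 0.2 * base_lr over the next 12 hours
    if t_hours <= 1.0:
        return base_lr
    frac = min((t_hours - 1.0) / 12.0, 1.0)
    return base_lr * (0.2 ** frac)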
In the last three plots, black = predicted loss curve of the new LR schedule, blue = original (unoptimized) real loss curve, orange = new LR schedule.
We propose the RWKV language model, with alternating time-mix and channel-mix layers:
- The R, K, V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we call R the "receptance", and the sigmoid means it's in the 0~1 range.
- The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.
(1) We changed the normalization (denominator). For masked language models, we define:
(UPDATE: We are using the original AFT normalization in v2)
Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.
(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

$$W_{t,u,c}=f_h(t-u) \cdot \alpha_h(u) \cdot \beta_h(t)$$
Moreover we multiply the final output of the Time-mix layer by γ(t). The reason for the α, β, γ factors is that the effective context size is smaller when t is small, and this can be compensated for using the α, β, γ factors.
(UPDATE: We remove the α, β, γ factors in v2-RNN and restrict W to a simple form, so it can be rewritten as an RNN.)
- The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor. Initialize the R and W matrices to ZERO for fast & stable convergence. (A hedged sketch of this layer is shown after the token-shift code below.)
- Finally, we add an extra token-shift (time-shift mixing) as in https://github.com/BlinkDL/minGPT-tuned.
The token-shift explicitly uses (half the channels of this token) & (half the channels of prev token) to generate all vectors (QKV, RWKV, ...).
self.time_shift = nn.ZeroPad2d((0,0,1,-1)) # pad one zero step at the front of the time dim and crop the last = shift features by one token
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # half the channels from the previous token, half from the current token
Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.
However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).
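As mentioned above, here is a hedged sketch of the Channel-mix layer (GeGLU-like with an extra R gate; names and sizes are illustrative, and the token-shift mixing of its inputs is omitted):

import torch
import torch.nn as nn

class ChannelMix(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.key = nn.Linear(dim, hidden, bias=False)
        self.value = nn.Linear(hidden, dim, bias=False)
        self.receptance = nn.Linear(dim, dim, bias=False)
        nn.init.zeros_(self.receptance.weight)   # init R to ZERO, as suggested
        nn.init.zeros_(self.value.weight)        # init the output projection to ZERO

    def forward(self, x):
        k = torch.square(torch.relu(self.key(x)))    # reluSquared (from Primer)
        r = torch.sigmoid(self.receptance(x))        # R gate in the 0~1 range
        return r * self.value(k)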
My theory on the effectiveness of token-shift:
When we train a GPT, the hidden representation of a token has to accomplish two different objectives:
(1) Predict the next token. Sometimes this is easy (obvious next token).
(2) Collect all previous context info, so later tokens can use it. This is always hard.
The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.
You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.
p.s. There is an MHA_pro model in this repo with strong performance. Give it a try :)
In the usual transformer, a small model has difficulty copying tokens (such as person names) from the context. We add extra Q & K to the final output so that the model can directly copy (or avoid) tokens in the context. Afterwards the model teaches itself NER (named entity recognition), as you can see in the learned weights.
q = self.head_q(x)[:,:T,:] # projecting to 256-d
k = self.head_k(x)[:,:T,:] # projecting to 256-d
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).
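A hedged sketch of that variant, replacing the `c @ one_hot` line above (shapes follow the snippet: c is (B, T, T), idx is (B, T); scatter_reduce_ with "amax" is assumed to be available in your PyTorch version):

c_max = torch.zeros(c.shape[0], c.shape[1], self.config.vocab_size, device=c.device)
# for each (batch, target pos, token id): take the max score over the context positions holding that token
c_max.scatter_reduce_(-1, idx.unsqueeze(1).expand(-1, c.shape[1], -1), c, reduce="amax", include_self=False)
x = self.head(x) + c_max   # max over repeated context tokens instead of sum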
We also propose a new sampling method called top-a (as in src/utils.py):
(1) Find the max probability p_max after softmax.
(2) Remove all entries whose probability is lower than 0.2 * pow(p_max, 2). So it's adaptive, hence "top-a".
(3) Feel free to tune the 0.2 and 2 factors. Tune the 0.2 first.
The idea of top-a:
- If max_prob=0.9, then remove all tokens with prob < 0.162 (so, removing all alternatives)
- If max_prob=0.5, then remove all tokens with prob < 0.05 (so, allowing more choices)
- If max_prob=0.1, then remove all tokens with prob < 0.002 (so, allowing lots of possibilities)
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), 2) * 0.2 # threshold = 0.2 * p_max^2 (tune the 0.2 and the exponent 2)
logits[probs < limit] = -float('Inf') # mask low-probability tokens before sampling
Character-level loss on simplebooks-92 dataset https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
Gray: usual MHA+Rotary+GeGLU - performance not as good. 17.2M params.
Red: RWKV ("linear" attention) - VRAM friendly - much faster when the ctx window is long - good performance. 16.6M params.
Green: MHA+Rotary+GeGLU+Token_shift. 17.2M params.
Blue: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance. 16.6M params.
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = aug,
year = 2021,
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196577},
url = {https://doi.org/10.5281/zenodo.5196577}
}
We use careful initialization for RWKV to get fast convergence - orthogonal matrices with proper scaling, and special time_w curves. Check model.py for details.
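A hedged sketch of that kind of initialization (illustrative only; see RWKV_Init in model.py for the real thing):

import torch.nn as nn

def init_weight(layer, gain=1.0, zero=False):
    # zero-init for most projections; scaled orthogonal init for the rest
    if zero:
        nn.init.zeros_(layer.weight)
    else:
        nn.init.orthogonal_(layer.weight, gain=gain)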
Some learned time_w examples: