"Stitching together a monstrous, efficient LLM!"
This project is an attempt at optimizing the llama2-7b-chat
transformer model by selectively skipping and repeating attention blocks. This involves altering the existing Llama codebase (in particular the modeling_llama.py
module from Hugging Face Transformers) and analyzing the resulting attention maps, as well as the Fourier Transform of those attention maps.
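As a rough illustration of the idea (not the project's actual implementation, which patches modeling_llama.py directly; the model name and the layer plan below are just an example), a skip pattern can be prototyped by re-indexing the decoder-layer ModuleList of the Hugging Face model:

```python
# Illustrative sketch only: prototype a layer-skip pattern by re-indexing the
# decoder ModuleList. The project itself patches modeling_llama.py instead.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16
)

# Example plan: keep layers 0-23 and 27-31, i.e. skip layers 24-26.
layer_plan = list(range(0, 24)) + list(range(27, 32))
model.model.layers = torch.nn.ModuleList(
    [model.model.layers[i] for i in layer_plan]
)

# Note: repeated indices would reuse (share) the same weights, and each layer
# keeps its original layer_idx, so KV-cache handling may need extra care.
print(f"Decoder layers after re-indexing: {len(model.model.layers)}")  # 29 of 32
```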
Simply create a Python environment and install the required dependencies:
python -m venv .venv
On Linux:
source .venv/bin/activate
On Windows:
.venv/Scripts/activate
pip install -r requirements.txt
The playground.ipynb notebook contains a walkthrough and description of the whole project, including a sample of the Fourier Transform of the baseline model and a quick analysis of the attention maps.
These scripts are the ones I used to test the various configurations of the model. Each one tests the configurations on a different dataset or generates a specific output artifact (e.g. attention maps).
To use them, first activate the environment you created above, then run one of them from the root folder.
On Linux:
source .venv/bin/activate
python scripts/<script_name>.py
On Windows:
.venv/Scripts/activate
python.exe scripts/<script_name>.py
A total of 25 configurations were first tested with a preliminary qualitative evaluation, which consisted of completing the sentence "Once upon a time".
However, only 6 configurations were ultimately chosen for further testing on the HellaSwag dataset.
A description of some of the configurations used can be found in docs/configuration.md.
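In practice this kind of qualitative check boils down to something like the following sketch, shown here on the baseline model (a skip/repeat variant would be built first, e.g. via the patched modeling_llama.py or the re-indexing sketch above):

```python
# Sketch of the qualitative check: complete the prompt "Once upon a time".
# Shown on the baseline llama2-7b-chat; a skip/repeat variant would be built
# beforehand (via the patched modeling_llama.py or the re-indexing sketch above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.eval()

inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```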
The results vary a lot from configuration to configuration. However, a few things can be noted quickly:
- Introducing repeats made the model more likely to generate gibberish and lose its logical capabilities.
- Skipping layers can actually make execution time higher.
Here's a sample of the gibberish generated by some of the configurations:
Configuration | Generated text |
---|---|
0-7 | Once upon a timezetempreasacondarichte?? trickster goddess pue moonkennecticut [..] Reserveikaiwitzter PetersburgovPortail [..] |
all_except_last_two | Once upon a time year0 **stadt [..] Death it Yearwaltapk Progress R?f?rencePU. ??? [..] |
only_even_layers | Once upon a time S??. R S l d d? S S S S S [..] |
first_last_2_with_skips | Once upon a time, in the midst of a busy schedulesomeone's attention was caught.?You the world and its of the, and and and [..] |
first_last_8r2 | Once upon a time in?ceycofortia-inaymskoe Bridge---Monlinaiticky'830 [..] |
15r3_23r3_31r3 | Once upon a timepus pri rosgemeingemeingemeinwach junigemeingemei [..] |
Each '?' stands in for a non-ASCII character. Here's a brief description of these configs:
- 0-7: uses only layers 0 to 7.
- all_except_last_two: uses all layers except the last 2.
- only_even_layers: uses only even layers.
- first_last_2_with_skips: repeats layers 0 and 31, skips layers [15, 16, 24, 25].
- first_last_8r2: uses only the first and last 8 layers, and repeats them twice.
- 15r3_23r3_31r3: only uses layers 15, 23, 31 and repeats them 3 times. The craziest one if you ask me.
It's interesting to see how some phrases still contain words that are somewhat semantically coherent with the starting text (e.g. they are related to time, or to characters from tales, etc.). Some of them are just straight up yapping though.
These are the results (accuracy and average execution time) on 50 samples of the HellaSwag dataset:
Configuration | HellaSwag | Avg. exec. time |
---|---|---|
baseline | 0.34 | 91.1 s |
0-23_27-31 | 0.38 | 81.2 s |
15_single_skip | 0.38 | 95.8 s |
mid_expansion_with_repeats | 0.22 | 68.8 s |
2_3rd_of_llama | 0.26 | 95.4 s |
2_3rds_plus_llama | 0.30 | 102.7 s |
skip_near_end_keep_last_two | 0.42 | 79.0 s |
Configs:
- baseline: the untouched llama2-7b-chat from HF.
- 0-23_27-31: skips layers from 24 to 26.
- 15_single_skip: skips only layer 15.
- mid_expansion_with_repeats: skips layers [6, 7, 8, 9, 25, 26, 27, 28], repeats layers from 14 to 19 twice.
- 2_3rd_of_llama: skips layers from 11 to 20.
- 2_3rds_plus_llama: skips only odd-indexed layers from 11 to 20.
- skip_near_end_keep_last_two: skips layers from 27 to 29.
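For context, here is a minimal sketch of how a single HellaSwag sample can be scored by ending log-likelihood, which is the usual way this benchmark is evaluated; the scripts in this repo may differ in the details, and model/tokenizer are assumed to be loaded as in the earlier sketches:

```python
# Sketch: score a HellaSwag sample by average log-likelihood of each ending.
# `model` and `tokenizer` are assumed to be a causal LM and its tokenizer,
# loaded as in the earlier sketches; the repo's scripts may differ in details.
import torch

def ending_logprob(model, tokenizer, context: str, ending: str) -> float:
    """Average log-probability of the ending tokens given the context."""
    ctx_len = tokenizer(context, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(context + " " + ending, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits  # (1, seq_len, vocab)
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)  # position i predicts token i+1
    # Tokens after the context (the boundary is approximate due to retokenization).
    positions = range(ctx_len - 1, full_ids.shape[1] - 1)
    token_lps = [log_probs[p, full_ids[0, p + 1]] for p in positions]
    return torch.stack(token_lps).mean().item()

def predict_ending(model, tokenizer, context: str, endings: list) -> int:
    """Index of the ending the model finds most likely."""
    scores = [ending_logprob(model, tokenizer, context, e) for e in endings]
    return max(range(len(endings)), key=lambda i: scores[i])
```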
Here's an example of the Fourier Transform of the attention maps of the skip_near_end_keep_last_two model on the input 'Once upon a time':
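As a minimal sketch of how such a transform can be computed and displayed (the helper name and the dummy attention map below are placeholders, not code from this repo):

```python
# Sketch: centered log-magnitude 2D FFT of an attention map.
# In practice `attn` would be one head of the tuple returned by a forward pass
# with output_attentions=True, e.g. out.attentions[layer][0, head]; a random
# row-normalized lower-triangular map stands in for it here.
import numpy as np
import matplotlib.pyplot as plt

def fft_of_attention_map(attn: np.ndarray) -> np.ndarray:
    """Centered log-magnitude spectrum of a (seq_len, seq_len) attention map."""
    spectrum = np.fft.fftshift(np.fft.fft2(attn))
    return np.log1p(np.abs(spectrum))

attn = np.tril(np.random.rand(16, 16))          # causal-looking dummy attention map
attn = attn / attn.sum(axis=-1, keepdims=True)  # rows sum to 1, like softmax output
plt.imshow(fft_of_attention_map(attn), cmap="viridis")
plt.title("Log-magnitude FFT of an attention map")
plt.colorbar()
plt.show()
```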
- Since one of the ultimate goals of the project is to reduce the memory footprint of the model, it would be nice to experiment with KV cache compression. In a future extension of the project, this will probably be one of the focus points.