Making Llama2 more computationally and memory efficient by skipping and repeating multiple attention layers.

AngeloGalav/franken-llama

franken-Llama

Frankenstein Llama

"Stitching together a monstrous, efficient LLM!"

This project is an attempt at optimizing the llama2-7b-chat transformer model by selectively skipping and repeating attention blocks. This involves altering the existing Llama codebase (in particular, the modeling_llama.py module of Hugging Face Transformers) and analyzing the resulting attention maps, as well as the Fourier transform of those attention maps.
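The core idea (choosing which decoder layers run, and how many times each) can be sketched as a small scheduling helper. The code below is a hypothetical illustration, not the project's actual patch to modeling_llama.py; the function name and config shape are made up for the example:

```python
def build_layer_schedule(num_layers, skip=(), repeat=None):
    """Return the ordered list of decoder-layer indices to execute.

    skip:   layer indices to drop entirely.
    repeat: optional {index: count} mapping; each listed layer runs
            `count` times in a row instead of once.
    """
    repeat = repeat or {}
    schedule = []
    for i in range(num_layers):
        if i in skip:
            continue  # skipped layers never appear in the schedule
        schedule.extend([i] * repeat.get(i, 1))
    return schedule

# llama2-7b has 32 decoder layers; skipping layers 24-26 mirrors
# the "0-23_27-31" configuration described below
print(build_layer_schedule(32, skip=range(24, 27)))
```

A repeated-layer config like 15r3_23r3_31r3 would then be expressed as `skip` covering every other index plus `repeat={15: 3, 23: 3, 31: 3}`.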

Setup

Simply create a Python environment and install the required dependencies:

  1. python -m venv .venv
  2. Activate it. On Linux: source .venv/bin/activate; on Windows: .venv/Scripts/activate
  3. pip install -r requirements.txt

Notebook

The playground.ipynb notebook contains a walkthrough and description of the whole project, including a sample of the Fourier transform of the baseline model and a quick analysis of the attention maps.

Batch scripts

These scripts are the ones I used to test the various configurations of the model. Each one tests the configurations on a different dataset, or generates a specific output artifact (e.g. attention maps).

To use them, first activate the environment you created earlier, then run one of them from the root folder.

On Linux:

source .venv/bin/activate
python scripts/<name_of_the_script>.py

On Windows:

.venv/Scripts/activate
python.exe scripts/<name_of_the_script>.py

Results

A total of 25 configurations were first tested with a preliminary qualitative evaluation, which consisted of completing the sentence "Once upon a time". However, only 6 configurations were ultimately chosen for further testing on the HellaSwag dataset. A description of some of the configurations used can be found in docs/configuration.md.
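For context, HellaSwag is a multiple-choice sentence-completion benchmark: the model is typically scored by picking, for each context, the candidate ending with the highest (often length-normalized) log-likelihood. A minimal sketch of that accuracy computation (hypothetical helper, not the project's actual evaluation harness):

```python
def hellaswag_accuracy(samples):
    """samples: list of (per_ending_logprobs, per_ending_lengths, label).

    Picks the ending with the best length-normalized log-likelihood
    and returns the fraction of samples where it matches the label.
    """
    correct = 0
    for logprobs, lengths, label in samples:
        # normalize by token count so longer endings aren't penalized
        scores = [lp / n for lp, n in zip(logprobs, lengths)]
        pred = scores.index(max(scores))
        correct += pred == label
    return correct / len(samples)

# made-up demo numbers, 4 candidate endings per sample
demo = [
    ([-12.0, -6.0, -20.0, -15.0], [4, 3, 5, 5], 1),
    ([-8.0, -11.0, -4.5, -9.0], [4, 4, 3, 4], 2),
]
print(hellaswag_accuracy(demo))  # 1.0
```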

The results vary considerably from configuration to configuration. However, two things can be quickly noted:

  • Introducing repeats increased the likelihood that the model generates gibberish and loses its logical capabilities.
  • Skipping layers can cause execution time to increase.

Generated text

Here's a sample of the gibberish generated by some of the configurations:

| Configuration | Generated text |
| --- | --- |
| 0-7 | Once upon a timezetempreasacondarichte?? trickster goddess pue moonkennecticut [..] Reserveikaiwitzter PetersburgovPortail [..] |
| all_except_last_two | Once upon a time year0 **stadt [..] Death it Yearwaltapk Progress R?f?rencePU. ??? [..] |
| only_even_layers | Once upon a time S??. R S l d d? S S S S S [..] |
| first_last_2_with_skips | Once upon a time, in the midst of a busy schedulesomeone's attention was caught.?You the world and its of the, and and and [..] |
| first_last_8r2 | Once upon a time in?ceycofortia-inaymskoe Bridge---Monlinaiticky'830 [..] |
| 15r3_23r3_31r3 | Once upon a timepus pri rosgemeingemeingemeinwach junigemeingemei [..] |

Each '?' stands in for a non-ASCII character in the original output. Here's a brief description of these configs:

  • 0-7: uses only layers 0 to 7
  • all_except_last_two: uses all layers except the last 2.
  • only_even_layers: uses only even layers.
  • first_last_2_with_skips: repeats layers 0 and 31, skips [15, 16, 24, 25]
  • first_last_8r2: uses only the first and last 8 layers, and repeats them twice.
  • 15r3_23r3_31r3: only uses layers 15, 23, and 31, repeating each 3 times. The craziest one, if you ask me.

It's interesting to see how some outputs still contain words that are somewhat semantically coherent with the starting text (e.g. they are related to time, or to characters from tales). Some of them are just straight-up yapping, though.

HellaSwag Dataset

These are the results on 50 samples of the HellaSwag dataset:

| Configuration | HellaSwag | Avg. exec. time |
| --- | --- | --- |
| baseline | 0.34 | 91.1 s |
| 0-23_27-31 | 0.38 | 81.2 s |
| 15_single_skip | 0.38 | 95.8 s |
| mid_expansion_with_repeats | 0.22 | 68.8 s |
| 2_3rd_of_llama | 0.26 | 95.4 s |
| 2_3rds_plus_llama | 0.30 | 102.7 s |
| skip_near_end_keep_last_two | 0.42 | 79.0 s |

Configs:

  • baseline: the untouched llama2-7b-chat from HF.
  • 0-23_27-31: skips layers from 24 to 26.
  • 15_single_skip: skips only layer 15.
  • mid_expansion_with_repeats: skips layers [6, 7, 8, 9, 25, 26, 27, 28], repeats layers from 14 to 19 twice.
  • 2_3rd_of_llama: skips layers from 11 to 20.
  • 2_3rds_plus_llama: skips only odd indexed layers from 11 to 20.
  • skip_near_end_keep_last_two: skips layers from 27 to 29.
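Conceptually, every configuration above only changes which decoder layers the forward pass visits and how many times. A self-contained toy of that modified forward loop (the layer objects are stand-ins, not Hugging Face code; the real change lives inside modeling_llama.py):

```python
class ToyLayer:
    """Stand-in for a transformer decoder block, identified by its index."""

    def __init__(self, idx):
        self.idx = idx

    def __call__(self, hidden):
        # the toy "hidden state" is just a trace of executed layer indices
        return hidden + [self.idx]

def run_with_schedule(layers, schedule):
    """Run the layers in the order given by `schedule` (skips/repeats included)."""
    hidden = []
    for i in schedule:
        hidden = layers[i](hidden)
    return hidden

layers = [ToyLayer(i) for i in range(6)]
# "skip layer 3, repeat layer 5 twice" expressed as an explicit schedule
print(run_with_schedule(layers, [0, 1, 2, 4, 5, 5]))
```

In the real model, the same pattern would mean iterating over `self.layers` in schedule order rather than sequentially, while reusing each layer's weights for the repeated passes.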

Fourier transform example

Here's an example of a Fourier transform of the attention maps of the skip_near_end_keep_last_two model on the input 'Once upon a time':

[Figure: Fourier transform of the attention maps]
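To make concrete what "Fourier transform of an attention map" means, here is a tiny stdlib-only 2D DFT magnitude over a toy matrix. It's a naive O(n⁴) sketch for illustration only; the project presumably uses a library FFT on real attention maps:

```python
import cmath

def dft2_magnitude(a):
    """Naive 2D DFT magnitude of a small real matrix (demo only)."""
    n, m = len(a), len(a[0])
    out = []
    for u in range(n):
        row = []
        for v in range(m):
            s = sum(
                a[x][y] * cmath.exp(-2j * cmath.pi * (u * x / n + v * y / m))
                for x in range(n)
                for y in range(m)
            )
            row.append(abs(s))
        out.append(row)
    return out

# toy "attention map": uniform attention puts all energy in the DC bin
attn = [[0.25] * 4 for _ in range(4)]
mag = dft2_magnitude(attn)
print(round(mag[0][0], 6))  # DC component = sum of all entries = 4.0
```

Structured attention patterns (diagonals, bands) show up as energy concentrated in specific frequency bins, which is what makes the spectrum a useful diagnostic.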

Future work

  • Since one of the ultimate goals of the project is to reduce the memory footprint of the model, it would be interesting to experiment with KV cache compression. In a future extension of the project, this will probably be one of the focus points.
