Transformers Can Do Arithmetic with the Right Embeddings! Link to arXiv paper: https://arxiv.org/abs/2405.17399
A joint project by: Sean McLeish, Arpit Bansal, Alex Stein, Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild and Tom Goldstein
This repository contains code to replicate our research. It is a fork of the language model training framework cramming, edited for a next-token prediction objective.
We provide a standalone implementation of Abacus Embeddings in abacus.py.
To cite our work, please use this BibTeX entry:
@article{mcleish2024transformers,
title={Transformers Can Do Arithmetic with the Right Embeddings},
author={Sean McLeish and Arpit Bansal and Alex Stein and Neel Jain and John Kirchenbauer and Brian R. Bartoldson and Bhavya Kailkhura and Abhinav Bhatele and Jonas Geiping and Avi Schwarzschild and Tom Goldstein},
journal={arXiv preprint arXiv:2405.17399},
year={2024}
}
We developed in Python 3.10.4. To install, run:
git clone [email protected]:mcleish7/arithmetic.git
cd arithmetic
pip install .
On some machines you will need to run:
pip install multiprocess -U
pip install dill -U
pip install apache-beam -U
We release our datasets on Google Drive, in both zipped and unzipped formats. We recommend working with the zipped version until it is correctly placed in your file system.
Alternatively, you can create your own datasets with create_data_split.py, using the commands from shells/generate_and_tokenize_data.sh.
We recommend creating another directory cramming-data
inside of arithmetic. This is where the models, logs and data will be stored.
You can either export your cramming base directory path to your .bashrc, or replace $cramming_base_dir manually in the provided shells.
cd arithmetic
mkdir cramming-data
echo 'export cramming_base_dir=MY_BASE_DIR' >> ~/.bashrc
source ~/.bashrc
For example, this may look like: echo 'export cramming_base_dir=~/arithmetic/cramming-data' >> ~/.bashrc
For example, our file system looks like:
cramming-generative
└── cramming-data
├── addition-train-one
│ ├── pretrain/<DATE>/<TIME>
│ │ ├── .hydra
│ │ │ ├── config.yaml
│ │ │ ├── hydra.yaml
│ │ │ └── overrides.yaml
│ │ └── addition-train-one_pretrain.log
│ ├── checkpoints/FINAL_<LOSS_VAL>
│ │ ├── model_config.json
│ │ ├── model.safetensors
│ │ └── state_dict.pth
│ └── downstream
└── data
└── arithmetic_data
├── +_grid_eval_dataset_reverse_all_tokenized
└── ... other datasets ...
Example commands are in the shells directory, organised by task.
- Give samples instead of tokens equal importance in the loss (a minimal sketch of per-sample loss averaging follows this group of flags):
arch.loss_reduction=none
- Divide the gradients in the recurrent block by the number of recurrences:
arch.throttle=True
- Mask before the equals sign:
arch.mask_before_equals=True
- Skip connections inside of the recurrent block:
arch.forward_only_model_with_skip=True
- Multi-GPU: replace
python
with
torchrun --nproc_per_node=<NUM GPUS> --standalone
and add impl.fullgraph=false
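For intuition, the following is a minimal sketch (not the repository's exact implementation) of what per-sample rather than per-token loss averaging looks like for arch.loss_reduction=none; the function name and tensor shapes are illustrative.

```python
import torch.nn.functional as F

def per_sample_loss(logits, labels, ignore_index=-100):
    # Illustrative sketch: average the loss within each sample first, then
    # across the batch, so every sample counts equally regardless of how many
    # tokens it contains.
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len)
    token_loss = F.cross_entropy(
        logits.transpose(1, 2), labels, ignore_index=ignore_index, reduction="none"
    )  # (batch, seq_len)
    mask = (labels != ignore_index).float()
    per_sample = (token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return per_sample.mean()
```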
- Learned:
arch.embedding.pos_embedding=learned
- Abacus:
arch.embedding.pos_embedding=abacus
- If you want the maximum k in Abacus to be larger:
arch.embedding.max_abacus_len=100
By default this value is 100. Abacus is also implemented in a standalone manner in abacus.py; a minimal sketch of the idea also appears after this list of embedding options.
- NoPE:
arch.embedding.pos_embedding=None
- FIRE:
arch.embedding.pos_embedding=None arch.attention.type="self-attention" arch.attention.rotary_embedding="fire"
- FIRE randomised, e.g.:
arch.embedding.pos_embedding=None arch.attention.type="self-attention" arch.attention.rotary_embedding="fire" arch.attention.max_length=128
By default arch.attention.max_length=0, so setting this larger than the maximum sequence length gives some randomness in the embedding.
- RoPE:
arch.attention.type="self-attention" arch.attention.rotary_embedding=true
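For reference, below is a minimal, self-contained sketch of the Abacus idea: each digit token gets a learned positional embedding indexed by its position within its own number, and a random offset (bounded by max_abacus_len) is added during training so that larger positions also receive gradient. The class and argument names here are illustrative; abacus.py in this repository is the reference implementation.

```python
import torch
import torch.nn as nn

class AbacusSketch(nn.Module):
    """Illustrative sketch of Abacus-style embeddings (see abacus.py for the
    reference implementation). Digit tokens are numbered 1, 2, 3, ... within
    each number; non-digit tokens get a zero embedding."""

    def __init__(self, digit_token_ids, embedding_dim, max_k=100):
        super().__init__()
        self.register_buffer("digit_ids", torch.as_tensor(digit_token_ids))
        self.max_k = max_k  # cf. arch.embedding.max_abacus_len
        self.embedding = nn.Embedding(max_k + 1, embedding_dim, padding_idx=0)

    def forward(self, input_ids):
        batch, seq_len = input_ids.shape
        is_digit = torch.isin(input_ids, self.digit_ids)       # (batch, seq_len)
        idx = torch.arange(seq_len, device=input_ids.device).expand(batch, seq_len)
        # Index of the most recent non-digit token seen so far (-1 if none yet).
        boundary = torch.where(is_digit, torch.full_like(idx, -1), idx)
        last_boundary = torch.cummax(boundary, dim=1).values
        positions = idx - last_boundary        # 1, 2, 3, ... inside a number; 0 elsewhere
        if self.training:
            # Random offset shared across the batch so positions beyond the
            # training operand lengths still get trained.
            offset = torch.randint(0, self.max_k, (1,), device=input_ids.device)
            positions = torch.where(is_digit, positions + offset, positions)
        positions = positions.clamp(max=self.max_k)
        return self.embedding(positions)
```

The output would be added to the token embeddings in the same way as any other positional embedding.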
We have implemented checkpointing for single-GPU training; to use it, pass:
impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last'
This saves a checkpoint every 60 minutes under the name 'last'.
Caution: this feature is not fully tested for multi-GPU cases. We also cannot currently continue training a model once it has used its full budget.
You can log runs to your Weights & Biases account. To do so, simply modify wandb.entity and wandb.project on the command line or in cramming/config/wandb/default.yaml.
We show examples in shells/evaluation.sh.
We provide very basic automation in gen_eval_script.py; it prints the basic commands, which you may need to edit further.
For addition we have a very large possible evaluation set, so we do a grid search over a 100x100 grid, which we split into 20 pieces with the aim of balancing the number of forward calls across all 20 pieces. We then run a further evaluation for operand lengths 100->160.
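To picture the split, the sketch below greedily partitions the grid cells into 20 chunks of roughly equal estimated cost. The cost model used here (proportional to the longer operand, i.e. roughly the number of answer tokens to generate) is only an assumption for illustration; the actual split is defined by the evaluation scripts and shells.

```python
import heapq

def split_grid(max_len=100, n_chunks=20):
    # Hypothetical sketch: assign (len_a, len_b) cells to chunks so that the
    # estimated number of forward calls (assumed to scale with the longer
    # operand) is roughly balanced across all chunks.
    cells = [(a, b) for a in range(1, max_len + 1) for b in range(1, max_len + 1)]
    cells.sort(key=max, reverse=True)                       # most expensive cells first
    heap = [(0, chunk_id, []) for chunk_id in range(n_chunks)]
    heapq.heapify(heap)
    for cell in cells:
        cost, chunk_id, members = heapq.heappop(heap)       # currently lightest chunk
        members.append(cell)
        heapq.heappush(heap, (cost + max(cell), chunk_id, members))
    return [members for _, _, members in sorted(heap, key=lambda entry: entry[1])]
```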
We only evaluate up to 25x25, which we do in a single job.
Sorting uses a separate evaluation file, sort_eval.py, because the evaluation calls cannot be parallelised (the position of the equals sign is not fixed within a batch), which makes evaluation much slower.
We currently evaluate across 30 jobs for a 30x30 grid, but this can be reduced to a smaller number of jobs using these flags: max_size_given, start_ind_1_given, start_ind_2_given.
We use the same framework as for addition, but the process is quicker because some batches contain fewer than 100 samples (there are not 100 possibilities for every batch). Unlike addition, we do not sample with replacement for this task.
- We provide pretty_plotter.py to combine the small evaluation grids into one plot. Use it by putting the model name into the string at the top of the main function.
- For the large 100x100 grids we provide pretty_plotter_big.py. These scripts are designed to be as flexible as possible but may need to be edited to fit your file set-up.
- For sorting, we provide pretty_plotter_sort.py, which reads the individual .txt files created during testing and merges them into a single plot.
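As a rough illustration of the merge these scripts perform, the sketch below turns per-cell accuracies into a single heatmap, assuming the results have already been parsed into a dictionary mapping (first operand length, second operand length) to accuracy; parsing of your particular result files is omitted.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_accuracy_grid(results, max_len, out_path="accuracy_grid.png"):
    # Hypothetical sketch: `results` maps (len_a, len_b) -> accuracy in [0, 1].
    grid = np.full((max_len, max_len), np.nan)   # NaN marks cells with no result
    for (len_a, len_b), acc in results.items():
        grid[len_a - 1, len_b - 1] = acc
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(grid, origin="lower", vmin=0.0, vmax=1.0, cmap="viridis",
                   extent=(0.5, max_len + 0.5, 0.5, max_len + 0.5))
    ax.set_xlabel("Second operand length")
    ax.set_ylabel("First operand length")
    fig.colorbar(im, ax=ax, label="Accuracy")
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
```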
Please feel free to contact us with any questions, or open an issue on GitHub.