feat: wip training log
zanussbaum committed Apr 13, 2023
1 parent 1280edd commit b170eb9
Showing 3 changed files with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions TRAINING_LOG.md
@@ -235,3 +235,46 @@ Taking inspiration from [the Alpaca Repo](https://github.com/tatsu-lab/stanford_
Comparing our LoRA model to the [Alpaca LoRA](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Training for 3 epochs performed best, both on perplexity and on qualitative examples.

We tried training a full model using the parameters above, but found that the model diverged during the second epoch and that samples generated after training were worse than those from the first epoch.


## GPT-J Training

### Model Training Divergence

We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model led to divergence after epoch 1. ![](figs/overfit-gpt-j.png) We release the checkpoint saved after epoch 1.


Using Atlas, we extracted the embeddings and calculated the per-sequence loss. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher-loss items seemed to cluster. On further inspection, the highest-density clusters were prompt/response pairs that asked for creative generations such as `Generate a story about ...` ![](figs/clustering_overfit.png)
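
A per-sequence loss like the one visualized above can be computed by averaging the token-level cross-entropy over the non-padding tokens of each example. The sketch below uses Hugging Face `transformers`; the helper name and batching details are illustrative rather than the exact evaluation code:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-j-6b"  # base checkpoint referenced above
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

@torch.no_grad()
def per_sequence_loss(texts, max_length=1024):
    batch = tokenizer(texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=max_length)
    logits = model(**batch).logits
    # Standard causal LM shift: token t predicts token t+1.
    shift_logits = logits[:, :-1, :]
    shift_labels = batch["input_ids"][:, 1:]
    shift_mask = batch["attention_mask"][:, 1:]
    token_loss = F.cross_entropy(shift_logits.transpose(1, 2),
                                 shift_labels, reduction="none")
    # Average over non-padding tokens only, giving one loss value per sequence.
    return (token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
```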



### GPT4All-J Hyperparameters

We varied the learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md), but found no real difference in performance, qualitatively or quantitatively, when varying these parameters.



The final model was trained using the following hyperparameters with a linear warmup followed by a constant learning rate:

| Hyperparameter | Value |
|----------------|-------|
| Per-device batch size | 32 |
| Global batch size | 256 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max sequence length | 1024 |
| Weight decay | 0 |
| Warmup steps | 500 |
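
For reference, a configuration like the one above maps roughly onto Hugging Face `TrainingArguments` as sketched below; the output directory and the single-device accumulation assumption are illustrative, not the exact training script:

```python
from transformers import TrainingArguments

# Global batch size 256 with a per-device batch size of 32 implies 8-way data
# parallelism or an equivalent gradient accumulation factor; a single device
# with accumulation is assumed here for illustration.
training_args = TrainingArguments(
    output_dir="gpt4all-j",            # illustrative path
    per_device_train_batch_size=32,
    gradient_accumulation_steps=8,     # 256 / (32 * num_gpus); adjust per setup
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.0,
    warmup_steps=500,
    lr_scheduler_type="constant_with_warmup",  # linear warmup, then constant LR
)
# The max sequence length of 1024 is applied at tokenization time
# (truncation/max_length), not through TrainingArguments.
```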


The LoRA model was trained using the following hyperparameters with a linear warmup followed by a constant learning rate:

| Hyperparameter | Value |
|----------------|-------|
| Per-device batch size | 4 |
| Global batch size | 32 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max sequence length | 1024 |
| Weight decay | 0 |
| Warmup steps | 500 |
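
A corresponding LoRA setup with the `peft` library might look like the sketch below. The table only fixes the optimizer-level hyperparameters; the rank, alpha, dropout, and target modules shown here are assumptions for illustration:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments

base = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
lora_config = LoraConfig(
    r=8,                                  # assumed rank
    lora_alpha=32,                        # assumed scaling factor
    lora_dropout=0.05,                    # assumed dropout
    target_modules=["q_proj", "v_proj"],  # GPT-J attention projections (assumed choice)
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)

training_args = TrainingArguments(
    output_dir="gpt4all-j-lora",          # illustrative path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,        # global batch size 32; adjust per GPU count
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.0,
    warmup_steps=500,
    lr_scheduler_type="constant_with_warmup",
)
```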
Binary file added figs/clustering_overfit.png
Binary file added figs/overfit-gpt-j.png
