Commit 722e5b2

refactor how we treat datasets, because we're about to have more of them and we don't want them to clutter up the root dir etc. This is only step 1; I'm about to refactor a bunch of the dataloading: how the .bin files work and are loaded, how the DataLoader works, etc. This is all needed to support good evals and training at scale.

karpathy committed May 20, 2024
1 parent 6c8bc17 commit 722e5b2

Showing 10 changed files with 61 additions and 121 deletions.
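From the renames in this diff, the new data-script layout looks roughly like this (only 7 of the 10 changed files are rendered below; `data_common.py` is inferred from the new `from data_common import download_file` lines and is presumably one of the files not shown):

```
dev/data/
├── data_common.py        # new shared download helper (inferred from imports)
├── hellaswag.py          # moved from dev/hellaswag.py
├── mmlu.py               # moved from dev/mmlu.py
├── tinyshakespeare.py    # moved from prepro_tinyshakespeare.py
└── tinystories.py        # moved from prepro_tinystories.py
```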
30 changes: 15 additions & 15 deletions .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v4

- name: Install OpenMP
if: matrix.os != 'windows-latest'
run: |
if [ "${{ runner.os }}" == "Linux" ]; then
sudo apt-get update && sudo apt-get install -y libomp-dev
@@ -33,7 +33,7 @@ jobs:
run: pip install -r requirements.txt

- name: Run preprocessing
-run: python prepro_tinyshakespeare.py
+run: python dev/data/tinyshakespeare.py

- name: Train model
run: python train_gpt2.py --device=cpu
@@ -45,9 +45,9 @@ jobs:
$url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip'
$output = './make-bin-win64.zip'
$wc.DownloadFile($url, $output)
- name: Unzip Win32 Makefile
if: matrix.os == 'windows-latest'
run: |
unzip make-bin-win64.zip
@@ -59,26 +59,26 @@ jobs:
if: matrix.os == 'windows-latest'
shell: cmd
run: |
call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat"
call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat"
make-4.4.1\dist\make WIN_CI_BUILD=1 test_gpt2 train_gpt2
- name: Execute testing program (With OpenMP)
if: matrix.os != 'windows-latest'
run: OMP_NUM_THREADS=8 ./test_gpt2

- name: Execute Windows testing program (With OpenMP)
if: matrix.os == 'windows-latest'
shell: cmd
run: |
copy test_gpt2 test_gpt2.exe
test_gpt2.exe
- name: Compile training and testing program without OpenMP
if: matrix.os != 'windows-latest'
run: NO_OMP=1 make test_gpt2 train_gpt2

- name: Execute testing program (No OpenMP)
if: matrix.os != 'windows-latest'
run: ./test_gpt2

build-cuda-windows:
@@ -93,11 +93,11 @@ jobs:
$url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip'
$output = './make-bin-win64.zip'
$wc.DownloadFile($url, $output)
- name: Unzip Win32 Makefile
run: |
unzip make-bin-win64.zip
- name: Install Cuda Toolkit 12.4 on Windows
run: |
mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4"
@@ -134,9 +134,9 @@ jobs:
shell: cmd
working-directory: ${{ github.workspace }}
run: |
call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat"
call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat"
make-4.4.1\dist\make -j WIN_CI_BUILD=1 train_gpt2fp32cu test_gpt2fp32cu test_gpt2cu train_gpt2cu profile_gpt2cu
build-cuda-fp32:
runs-on: ubuntu-latest
container:
26 changes: 13 additions & 13 deletions README.md
@@ -12,7 +12,7 @@ The "I don't care about anything I just want to train and I have a GPU" section.

```bash
pip install -r requirements.txt
-python prepro_tinyshakespeare.py
+python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2fp32cu
./train_gpt2fp32cu
@@ -22,17 +22,17 @@ The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent

## quick start (GPU, fast bleeding edge)

-I want to see it go fast. In this case switch to our mainline, most optimized `train_gpt2.cu` and also turn on flash attention. Run:
+I want to see it go fast. In this case switch to our mainline, most optimized `train_gpt2.cu`. Run:

```bash
pip install -r requirements.txt
-python prepro_tinyshakespeare.py
+python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2cu
./train_gpt2cu
```

-If you additionally install cuDNN (see the CUDA section below), you can also go faster with flash attention
+If you additionally install cuDNN (see the CUDA section below), you can go even faster with flash attention. Adjust the make command as follows to compile with cudnn / flash attention:

```bash
make train_gpt2cu USE_CUDNN=1
@@ -48,9 +48,9 @@ Note that the default batch size is very low (4). If you have enough memory on y
My standard "prod" run with a nice GPU (e.g. A100 40GB) actually trains on TinyStories instead of TinyShakespeare, and looks like this:

```bash
-python prepro_tinystories.py
+python dev/data/tinystories.py
make train_gpt2cu USE_CUDNN=1
-./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -o stories.log -b 32
+./train_gpt2cu -i dev/data/tinystories/TinyStories -v 250 -s 250 -g 144 -o stories.log -b 32
```

Where I decrease the frequency of validation loss and sampling to every 250 steps, sample 144 tokens during sampling stage (to fit ~one story), and at batch size 32.
@@ -61,7 +61,7 @@ The "I am so GPU poor that I don't even have one" section. No worries, run:

```bash
pip install -r requirements.txt
-python prepro_tinyshakespeare.py
+python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2
OMP_NUM_THREADS=8 ./train_gpt2
@@ -73,10 +73,10 @@ The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent

You'll be using the (more bleeding edge) mixed precision version of the code:

-```
+```bash
sudo apt install openmpi-bin openmpi-doc libopenmpi-dev
pip install -r requirements.txt
-python prepro_tinyshakespeare.py
+python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2cu
mpirun -np <number of GPUs on your machine> ./train_gpt2cu
@@ -89,17 +89,17 @@ Sub in the number of GPUs you'd like to run on in the last command.
Download and tokenize a dataset. The [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset is the fastest to download and tokenize:

```bash
-python prepro_tinyshakespeare.py
+python dev/data/tinyshakespeare.py
```

This prints:

```
-Saved 32768 tokens to data/tiny_shakespeare_val.bin
-Saved 305260 tokens to data/tiny_shakespeare_train.bin
+Saved 32768 tokens to (...)/tiny_shakespeare_val.bin
+Saved 305260 tokens to (...)/tiny_shakespeare_train.bin
```

-The .bin files are raw byte streams of int32 numbers indicating the token ids with the GPT-2 tokenizer. Alternatively you could also tokenize the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset with `prepro_tinystories.py`.
+The .bin files are raw byte streams of int32 numbers indicating the token ids with the GPT-2 tokenizer. Alternatively you could also tokenize the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset with `tinystories.py`.
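As a quick sanity check of that format, the token stream can be read straight back with numpy. A minimal sketch (the path assumes the post-refactor tinyshakespeare output location from this commit):

```python
# Read a tokenized .bin shard back into memory and decode a snippet.
# Assumes the headerless int32 stream described above.
import numpy as np
import tiktoken

tokens = np.fromfile("dev/data/tinyshakespeare/tiny_shakespeare_val.bin", dtype=np.int32)
print(f"read {len(tokens)} tokens")      # expect 32768 for the val split
enc = tiktoken.get_encoding("gpt2")
print(enc.decode(tokens[:64].tolist()))  # peek at the start of the stream
```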

In principle we'd be ready to train the model right here. However the baseline CPU/fp32 reference code is so inefficient that it's not practical to train these models from scratch yet. Instead, we initialize with the GPT-2 weights released by OpenAI and just do finetuning. For that, we have to download the GPT-2 weights and save them as a checkpoint we can load in C:

21 changes: 3 additions & 18 deletions dev/hellaswag.py → dev/data/hellaswag.py
@@ -29,14 +29,14 @@
import requests
import tiktoken
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import GPT2LMHeadModel
+from data_common import download_file

DATA_CACHE_DIR = os.path.join("data", "hellaswag")
# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")

hellaswags = {
"train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
@@ -46,21 +46,6 @@

enc = tiktoken.get_encoding("gpt2")

-def download_file(url: str, fname: str, chunk_size=1024):
-    """Helper function to download a file from a given url"""
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
-    with open(fname, "wb") as file, tqdm(
-        desc=fname,
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in resp.iter_content(chunk_size=chunk_size):
-            size = file.write(data)
-            bar.update(size)

def download(split):
"""Downloads HellaSwag DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
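The `download_file` helper deleted here (and, identically, from `mmlu.py`, `tinyshakespeare.py`, and `tinystories.py` below) presumably moves into the new shared module `dev/data/data_common.py` that all four scripts now import from. A sketch of that module, reconstructed from the removed copies:

```python
# dev/data/data_common.py -- sketch, reconstructed from the removed copies
import requests
from tqdm import tqdm

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)  # write() returns the number of bytes written
            bar.update(size)
```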
21 changes: 3 additions & 18 deletions dev/mmlu.py → dev/data/mmlu.py
@@ -15,33 +15,18 @@
import tiktoken
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import GPT2LMHeadModel
+from data_common import download_file

DATA_CACHE_DIR = os.path.join("data", "mmlu")
# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "mmlu")

enc = tiktoken.get_encoding("gpt2")
data_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"

-def download_file(url: str, fname: str, chunk_size=1024):
-    """Helper function to download a file from a given url"""
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
-    with open(fname, "wb") as file, tqdm(
-        desc=fname,
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in resp.iter_content(chunk_size=chunk_size):
-            size = file.write(data)
-            bar.update(size)

def download():
"""Downloads MMLU to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
33 changes: 8 additions & 25 deletions prepro_tinyshakespeare.py → dev/data/tinyshakespeare.py
@@ -3,48 +3,32 @@
- The download is from Github.
- The tokenization is GPT-2 tokenizer with tiktoken
-The output is written to a newly created data/ folder.
+The output is written to a newly created tinyshakespeare/ folder.
The script prints:
-Saved 32768 tokens to data/tiny_shakespeare_val.bin
-Saved 305260 tokens to data/tiny_shakespeare_train.bin
+Saved 32768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
+Saved 305260 tokens to tinyshakespeare/tiny_shakespeare_train.bin
And runs in a few seconds depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""

import os
import requests
from tqdm import tqdm

import tiktoken
import numpy as np
+from data_common import download_file

+# -----------------------------------------------------------------------------
+DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinyshakespeare")

DATA_CACHE_DIR = "data"
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})

-def download_file(url: str, fname: str, chunk_size=1024):
-    """Helper function to download a file from a given url"""
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
-    with open(fname, "wb") as file, tqdm(
-        desc=fname,
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in resp.iter_content(chunk_size=chunk_size):
-            size = file.write(data)
-            bar.update(size)

def download():
"""Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

-# download the TinyStories dataset, unless it's already downloaded
+# download the TinyShakespeare dataset, unless it's already downloaded
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
if not os.path.exists(data_filename):
@@ -54,7 +38,6 @@ def download():
print(f"{data_filename} already exists, skipping download...")

def tokenize():
-eot = enc._special_tokens['<|endoftext|>'] # end of text token
data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
text = open(data_filename, 'r').read()
# let's treat every person's statement in the dialog as a separate document
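The rest of `tokenize()` is collapsed in this view. Going only by the docstring (32,768 val tokens, int32 output) and the comment above, a hypothetical reconstruction (the split logic and names here are assumptions, not the collapsed code) might look like:

```python
# Hypothetical sketch of the collapsed tokenize() body; illustrative only.
# Assumes blank-line-delimited documents, each prefixed with <|endoftext|>
# (matching the encode lambda above), and a 32,768-token val split.
import os
import numpy as np

def tokenize_sketch(text, enc, data_cache_dir):
    tokens = []
    for doc in text.split("\n\n"):  # one "document" per speaker statement
        tokens.extend(enc.encode("<|endoftext|>" + doc, allowed_special={'<|endoftext|>'}))
    tokens = np.array(tokens, dtype=np.int32)
    tokens[:32768].tofile(os.path.join(data_cache_dir, "tiny_shakespeare_val.bin"))
    tokens[32768:].tofile(os.path.join(data_cache_dir, "tiny_shakespeare_train.bin"))
```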
33 changes: 10 additions & 23 deletions prepro_tinystories.py → dev/data/tinystories.py
@@ -3,13 +3,13 @@
- The download is from HuggingFace datasets.
- The tokenization is GPT-2 tokenizer with tiktoken
-The output is written to a newly created data/ folder.
+The output is written to a newly created tinystories/ folder.
The script prints:
Tokenizing val split...
-Saved 19043638 tokens to data/TinyStories_val.bin
+Saved 19043638 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
-Saved 925653391 tokens to data/TinyStories_train.bin
+Saved 925653391 tokens to tinystories/TinyStories_train.bin
And runs in 1-2 minutes depending on your internet
connection and computer. The .bin files are raw byte
@@ -23,29 +23,16 @@
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

import tiktoken
import numpy as np
+from data_common import download_file

+# -----------------------------------------------------------------------------
+DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinystories")

DATA_CACHE_DIR = "data"
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)

-def download_file(url: str, fname: str, chunk_size=1024):
-    """Helper function to download a file from a given url"""
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
-    with open(fname, "wb") as file, tqdm(
-        desc=fname,
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in resp.iter_content(chunk_size=chunk_size):
-            size = file.write(data)
-            bar.update(size)

def download():
"""Downloads the TinyStories dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -70,11 +57,11 @@ def download():

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
-    with open(shard_filenames[0], "r") as f:
-        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
-    #print(f"Example story:\n{data[0]}")
+    # with open(shard_filenames[0], "r") as f:
+    #     data = json.load(f)
+    # print(f"Example story:\n{data[0]}")

def process_shard(shard_index, shard_filename):
with open(shard_filename, "r") as f:
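`process_shard` itself is collapsed here, but the `ProcessPoolExecutor` import at the top suggests the shards are tokenized in parallel. A hypothetical driver, assuming `process_shard` returns the token list for one shard (aggregation and ordering details are guesses, not the collapsed code):

```python
from concurrent.futures import ProcessPoolExecutor, as_completed

def tokenize_all_shards(shard_filenames):
    # Fan the shards out across worker processes; collect results as they
    # finish (note: completion order, not shard order).
    all_tokens = []
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(process_shard, i, fn)
                   for i, fn in enumerate(shard_filenames)]
        for future in as_completed(futures):
            all_tokens.extend(future.result())
    return all_tokens
```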
8 changes: 4 additions & 4 deletions train_gpt2.c
@@ -1100,10 +1100,10 @@ int main() {
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

// build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories
const char* tiny_stories_train = "data/TinyStories_train.bin";
const char* tiny_stories_val = "data/TinyStories_val.bin";
const char* tiny_shakespeare_train = "data/tiny_shakespeare_train.bin";
const char* tiny_shakespeare_val = "data/tiny_shakespeare_val.bin";
const char* tiny_stories_train = "dev/data/tinystories/TinyStories_train.bin";
const char* tiny_stories_val = "dev/data/tinystories/TinyStories_val.bin";
const char* tiny_shakespeare_train = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
const char* tiny_shakespeare_val = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
const char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train;
const char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val;
int B = 4; // batch size 4 (i.e. 4 independent token sequences will be trained on)
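Note that these are relative paths, so the compiled binary presumably still has to be launched from the repository root for the tiny_shakespeare-else-tiny_stories fallback to find the .bin files.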