Commit
upgrade our .py script so it can also do torchrun with many processes, as our mixed precision mpi nccl code can do now
karpathy committed Apr 26, 2024
1 parent 0c3e3e3 commit 0852706
Showing 1 changed file with 82 additions and 35 deletions.
117 changes: 82 additions & 35 deletions train_gpt2.py
@@ -7,6 +7,12 @@
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
Example launches to only benchmark the speed of bfloat16 compiled GPU training:
1 GPU:
python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
4 GPU:
torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
"""

import os
@@ -20,6 +26,8 @@
import torch.nn as nn
from torch.nn import functional as F
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

class NewGELU(nn.Module):
"""Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
@@ -358,11 +366,17 @@ def write_tokenizer(enc, filename):
file.write(b) # Write the actual bytes
print(f"wrote {filename}")

def print0(*args, **kwargs):
# modified print that only prints from the master process
# if this is not a distributed run, it's just a print
if int(os.environ.get("RANK", 0)) == 0:
print(*args, **kwargs)

if __name__ == "__main__":
import time
import argparse
import tiktoken
print(f"Running pytorch {torch.version.__version__}")
print0(f"Running pytorch {torch.version.__version__}")

# default settings will overfit a tiny batch of data
# and save model weights and debug state to disk on the first iteration
@@ -383,26 +397,47 @@ def write_tokenizer(enc, filename):
assert 1 <= T <= 1024
assert args.dtype in {"float32", "float16", "bfloat16"}

# select the device
if args.device:
device = args.device
# set up DDP (distributed data parallel). torchrun sets this env variable
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
# use of DDP atm demands CUDA, we set the device appropriately according to rank
assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
seed_offset = ddp_rank # each process gets a different seed
else:
# attempt to autodetect the device
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
ddp_world_size = 1
master_process = True
seed_offset = 0
# select the device
if args.device:
# provided explicitly by the user
device = args.device
else:
# attempt to autodetect the device
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")

# set up a context manager following the desired dtype and device
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

# seed the random number generators
torch.manual_seed(42)
# seed the random number generators (in DDP we want each process to use a different seed offset)
# the code below doesn't actually consume random numbers yet, because there is no active dataloader
# loading real batches of data, but seeding is good practice and something to be careful and
# explicit about, so I am leaving it here.
torch.manual_seed(42 + seed_offset)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
torch.cuda.manual_seed(42 + seed_offset)
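A minimal sketch of what the offset buys (the loop variable stands in for ddp_rank on two processes): each rank seeds its generators differently, so any future per-rank randomness such as dropout or data shuffling is not identical on every GPU.

import torch
for rank in (0, 1):              # stand-ins for ddp_rank on two processes
    torch.manual_seed(42 + rank)
    print(rank, torch.rand(2))   # different seed -> different draws per "rank"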

# set the torch precision mode to use TensorFloat32 (TF32) for matmuls
# docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
@@ -413,7 +448,8 @@ def write_tokenizer(enc, filename):
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)
write_tokenizer(enc, "gpt2_tokenizer.bin")
if master_process and args.write_tensors: # tokenizer is technically not tensors but ok
write_tokenizer(enc, "gpt2_tokenizer.bin")

# load the GPT-2 model weights
model = GPT.from_pretrained("gpt2")
@@ -422,7 +458,7 @@ def write_tokenizer(enc, filename):
if args.compile:
if hasattr(config, "coordinate_descent_tuning"):
config.coordinate_descent_tuning = True # suggested by @Chillee
print("compiling the model...")
print0("compiling the model...")
model = torch.compile(model)

# -------------------------------------------------------------------------
@@ -436,7 +472,7 @@ def write_tokenizer(enc, filename):
assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
assert os.path.isfile(tokens_bin)
print(f"loading cached tokens in {tokens_bin}")
print0(f"loading cached tokens in {tokens_bin}")
with open(tokens_bin, "rb") as f:
tokens = np.frombuffer(f.read(), dtype=np.int32)

@@ -466,7 +502,7 @@ def get_batch():
# STAGE 1: weights / state logging for C to load later

# do one forward pass to generate ground truth for our C tests
if not args.inference_only and args.write_tensors:
if master_process and (not args.inference_only and args.write_tensors):
logits, loss = model(x, y)
loss.backward()
# save model params, in both float32 and bfloat16
@@ -479,6 +515,11 @@ def get_batch():
# -------------------------------------------------------------------------
# STAGE 2: training loop to get timings

# here we wrap model into DDP container
if ddp:
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model # always contains the "raw" unwrapped model
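Why keep both handles (a condensed sketch of the two lines above, assuming the DDP branch ran): training steps go through the wrapper so backward() averages gradients across ranks, while custom methods such as generate() live on the underlying module.

model = DDP(model, device_ids=[ddp_local_rank])   # backward() now all-reduces gradients across ranks
raw_model = model.module                          # the unwrapped GPT
logits, loss = model(x, y)                        # train through the wrapper
samples = raw_model.generate(x, 16)               # generate through the raw module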

# init the optimizer
adam_use_fused = device == "cuda" # only works on CUDA (?)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=adam_use_fused)
@@ -492,7 +533,7 @@ def get_batch():
logits, loss = model(x, y)
del logits
if not args.inference_only:
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
# wait on the CPU for all device work to end so we get accurate per-iteration timings below
@@ -505,27 +546,33 @@ def get_batch():
# the 0th iteration is often an outlier (much slower) => skip logging it
if i > 0 and i > args.num_iterations - 20:
timings.append(t1-t0)
print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
print0(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

# print the average of the last 20 timings, to get something smooth-ish
timings = timings[-20:]
print(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# -------------------------------------------------------------------------
# STAGE 3: Few steps of inference
if master_process:

# before we end, let's also do one round of inference
# we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
start = "<|endoftext|>"
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation for 16 time steps (tokens)
max_new_tokens = 16
temperature = 1.0
top_k = 40
raw_model.eval()
y = raw_model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print0(decode(y[0].tolist()))
print0('---------------')

# before we end, let's also do one round of inference
# we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
start = "<|endoftext|>"
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation for 16 time steps (tokens)
max_new_tokens = 16
temperature = 1.0
top_k = 40
model.eval()
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))
print('---------------')
# -------------------------------------------------------------------------
# clean up nice
if ddp:
destroy_process_group()
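For completeness, a self-contained sketch of the DDP lifecycle this diff adds (hypothetical ddp_sketch.py, not part of the commit; assumes a CUDA machine): initialize the process group, pin each worker to its local GPU, wrap, take one step, tear down.

# launch with: torchrun --standalone --nproc_per_node=2 ddp_sketch.py
import os
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = f"cuda:{local_rank}"
net = DDP(torch.nn.Linear(8, 8).to(device), device_ids=[local_rank])
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
x = torch.randn(4, 8, device=device)
loss = net(x).square().mean()
opt.zero_grad(set_to_none=True)
loss.backward()                  # gradients are averaged across the two workers here
opt.step()
destroy_process_group()          # clean shutdown, mirroring the end of train_gpt2.py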
