Add tokens/sec/device and new Fabric initialization (Lightning-AI#217)
lantiga authored May 3, 2023
1 parent 2f7b808 commit d598aa2
Showing 1 changed file with 38 additions and 18 deletions.
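
In short, the diff below moves model construction under the Fabric device context with bfloat16 as a temporary default dtype, and replaces the separate fabric.setup_module / fabric.setup_optimizers calls with a single fabric.setup. A minimal sketch of the resulting initialization pattern, with placeholder Fabric arguments and an nn.Linear standing in for LLaMA(config):

import lightning as L
import torch
import torch.nn as nn

# Placeholder Fabric configuration; the real script sets up its own devices and strategy.
fabric = L.Fabric(accelerator="auto", devices=1, precision="bf16-mixed")
fabric.launch()

with fabric.device:
    # Parameters are created directly on the Fabric device in bfloat16,
    # then the default dtype is restored for tensors created afterwards.
    torch.set_default_dtype(torch.bfloat16)
    model = nn.Linear(128, 128)  # stands in for LLaMA(config)
    torch.set_default_dtype(torch.float32)

optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=1e-1, betas=(0.9, 0.95))

# A single combined call replaces fabric.setup_module(model) followed by
# fabric.setup_optimizers(optimizer).
model, optimizer = fabric.setup(model, optimizer)
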
56 changes: 38 additions & 18 deletions train_redpajama.py
@@ -30,8 +30,8 @@

# Hyperparameters
learning_rate = 6e-4
-batch_size = 128
-micro_batch_size = 4
+batch_size = 125
+micro_batch_size = 5
max_iters = 600000 # num_epochs * epoch_size // devices
weight_decay = 1e-1
beta1 = 0.9
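
Further down in the file, these values feed process_batch_size = batch_size // devices and grad_accum_steps = process_batch_size // micro_batch_size. A worked example with a purely hypothetical devices = 5 (the device count is not part of this diff):

devices = 5             # hypothetical, for illustration only
batch_size = 125        # new value (was 128)
micro_batch_size = 5    # new value (was 4)

process_batch_size = batch_size // devices                  # 125 // 5 = 25
grad_accum_steps = process_batch_size // micro_batch_size   # 25 // 5 = 5

# samples consumed per optimizer step across all processes
effective_batch = micro_batch_size * grad_accum_steps * devices  # 5 * 5 * 5 = 125

With the old 128/4 pair and the same hypothetical device count, the integer divisions would give 128 // 5 = 25 and 25 // 4 = 6, for an effective batch of 4 * 6 * 5 = 120 rather than a value that matches batch_size exactly.
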
@@ -78,17 +78,6 @@ def main(

config = LLaMAConfig.from_name("7B")

-with fabric.device:
-torch.set_default_tensor_type(torch.HalfTensor)
-model = LLaMA(config).bfloat16()
-model.apply(model._init_weights)
-torch.set_default_tensor_type(torch.FloatTensor)
-
-# if compile:
-# model = torch.compile(model)
-
-model = fabric.setup_module(model)

train_dataloader, val_dataloader = create_dataloaders(
batch_size=micro_batch_size,
block_size=config.block_size,
@@ -98,13 +87,23 @@
)
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

+with fabric.device:
+torch.set_default_dtype(torch.bfloat16)
+model = LLaMA(config)
+model.apply(model._init_weights)
+torch.set_default_dtype(torch.float32)
+
+# if compile:
+# model = torch.compile(model)

optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(beta1, beta2),
)
-optimizer = fabric.setup_optimizers(optimizer)

+model, optimizer = fabric.setup(model, optimizer)

process_batch_size = batch_size // devices
grad_accum_steps = process_batch_size // micro_batch_size
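
grad_accum_steps drives the accumulation logic in the hunks that follow: an is_accumulating flag (its comparison flips from == 0 to != 0, so it is now True while gradients are still being accumulated) gates both fabric.no_backward_sync and the optimizer step. Condensed from this diff, reusing the script's own names (fabric, model, optimizer, train_dataloader, grad_accum_steps, grad_clip) and stubbing out the loss computation:

for iter_num, train_data in enumerate(train_dataloader):
    # True on accumulation iterations; every grad_accum_steps-th iteration
    # it becomes False and a full optimizer step is taken.
    is_accumulating = (iter_num + 1) % grad_accum_steps != 0

    # Skip the cross-process gradient synchronization while accumulating.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = compute_loss(model, train_data)    # stub for the forward pass and cross-entropy
        fabric.backward(loss / grad_accum_steps)  # scale so the accumulated gradients average out

    if not is_accumulating:
        fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
        optimizer.step()
        optimizer.zero_grad()
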
@@ -127,18 +126,24 @@ def train(

step_count = 0

+step_time = 0.0
+tokens = 0
+tokens_sec = 0.0
+prev_t1 = time.time()

for iter_num, train_data in enumerate(train_dataloader):
-t0 = time.time()

# determine and set the learning rate for this iteration
lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr

+t0 = time.time()

input_ids = train_data[:, 0 : model.config.block_size].contiguous()
targets = train_data[:, 1 : model.config.block_size + 1].contiguous()

-is_accumulating = (iter_num + 1) % grad_accum_steps == 0
+is_accumulating = (iter_num + 1) % grad_accum_steps != 0

with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
@@ -147,13 +152,17 @@
)
fabric.backward(loss / grad_accum_steps)

+t1 = time.time()

if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=grad_clip)

optimizer.step()
optimizer.zero_grad()
step_count += 1

+t1 = time.time()

if step_count % eval_interval == 0:
val_loss = validate(fabric, model, val_dataloader)
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
@@ -168,15 +177,26 @@
fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")
)

-dt = time.time() - t0
+dt = t1 - t0

+tokens += micro_batch_size * model.config.block_size
+step_time += t1 - prev_t1
+prev_t1 = t1

if iter_num % log_interval == 0:
+tokens_sec_str = f"{tokens / step_time:.0f}" if not is_accumulating else "-"

fabric.log_dict(
{"iter": iter_num, "train_loss": loss, "step": step_count, "lr": lr}
)
fabric.print(
f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms"
f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms, speed: {tokens_sec_str} toks/s/device"
)

+if not is_accumulating:
+tokens = 0
+step_time = 0.0

if iter_num > max_iters:
break
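
Taken together, the timing additions work as follows: tokens and step_time accumulate over every micro-batch of the current optimizer step, the per-device rate is printed only once a full step has completed (otherwise "-" is shown), and both counters reset after each optimizer step. A condensed view, reusing the script's names:

step_time = 0.0
tokens = 0
prev_t1 = time.time()

for iter_num, train_data in enumerate(train_dataloader):
    is_accumulating = (iter_num + 1) % grad_accum_steps != 0
    # ... forward/backward and, when not accumulating, the optimizer step ...
    t1 = time.time()

    tokens += micro_batch_size * model.config.block_size  # tokens seen by this device
    step_time += t1 - prev_t1                              # wall time since the previous iteration
    prev_t1 = t1

    if iter_num % log_interval == 0:
        tokens_sec_str = f"{tokens / step_time:.0f}" if not is_accumulating else "-"
        fabric.print(f"iter {iter_num}: speed: {tokens_sec_str} toks/s/device")

    if not is_accumulating:
        tokens = 0        # start a fresh measurement window after each optimizer step
        step_time = 0.0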
