Support multiple epoch iteration in pretrain script (Lightning-AI#820)
awaelchli authored Dec 16, 2023
1 parent cd733bc commit b0ccb00
Showing 4 changed files with 73 additions and 9 deletions.
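The gist of the change: the pretrain script derives its iteration budget from a token budget (max_iters = max_tokens_per_device // tokens_per_iter), but it previously walked the data with a plain iter(train_dataloader), which is exhausted after a single pass. The CycleIterator wrapper added in lit_gpt/utils.py below restarts the underlying iterable transparently and counts how many times it has wrapped around, so the loop can run for more iterations than one epoch provides. A minimal sketch of the difference, using a made-up three-item dataset rather than anything from the repository:

from lit_gpt.utils import CycleIterator

data = ["batch-0", "batch-1", "batch-2"]   # made-up stand-in for a dataloader

plain = iter(data)
for _ in range(3):
    next(plain)
# next(plain) would now raise StopIteration: a plain iterator ends after one pass.

cycling = CycleIterator(data)
print([next(cycling) for _ in range(5)])   # ['batch-0', 'batch-1', 'batch-2', 'batch-0', 'batch-1']
print(cycling.epoch)                       # 1 -> the wrapper is on its second pass over the data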
44 changes: 43 additions & 1 deletion lit_gpt/utils.py
@@ -5,7 +5,19 @@
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
-from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ContextManager,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    TypeVar,
+    Union,
+)
+from typing_extensions import Self

import lightning as L
import torch
@@ -347,3 +359,33 @@ def estimate_flops(model: "GPT", training: bool) -> int:
    # forward + backward
    frozen_ops_per_step = 2 if training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops


class CycleIterator:
    """An iterator that cycles through an iterable indefinitely.

    Example:
        >>> iterator = CycleIterator([1, 2, 3])
        >>> [next(iterator) for _ in range(5)]
        [1, 2, 3, 1, 2]

    Note:
        Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
    """

    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.epoch = 0
        self._iterator = None

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = iter(self.iterable)
            self.epoch += 1
            return next(self._iterator)

    def __iter__(self) -> Self:
        return self
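As a usage illustration (a minimal sketch with made-up sizes, not code from the repository), wrapping a regular PyTorch DataLoader lets a fixed iteration budget run past the end of the data while epoch records how many passes have completed:

from torch.utils.data import DataLoader
from lit_gpt.utils import CycleIterator

loader = DataLoader(list(range(4)), batch_size=2)  # hypothetical dataset: 2 batches per epoch
train_iterator = CycleIterator(loader)

for step in range(5):  # made-up iteration budget larger than one epoch
    batch = next(train_iterator)
    print(f"step {step}: epoch {train_iterator.epoch}, batch {batch.tolist()}")
# epoch stays 0 for the first two steps and becomes 1 when the loader is restarted on step 2.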
9 changes: 5 additions & 4 deletions pretrain/tinyllama.py
@@ -25,7 +25,7 @@

from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from lit_gpt.packed_dataset import CombinedDataset
-from lit_gpt.utils import chunked_cross_entropy, num_parameters
+from lit_gpt.utils import CycleIterator, chunked_cross_entropy, num_parameters

# System settings
model_name = "tiny-llama-1.1b"
@@ -150,7 +150,7 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
    tokens_per_iter = micro_batch_size * model.config.block_size
    max_iters = max_tokens_per_device // tokens_per_iter
    initial_iter = state["iter_num"]
-    train_iterator = iter(train_dataloader)
+    train_iterator = CycleIterator(train_dataloader)

    # resume data loader state by fast-forwarding through all seen batches
    # drop this once streaming dataset supports proper resuming
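For what the fast-forward comment above amounts to in practice: the already-consumed batches are replayed and discarded so that the wrapped dataloader, and with this change its epoch counter as well, end up where the interrupted run stopped. A self-contained sketch of the idea with made-up sizes (not the script's actual resume code):

from lit_gpt.utils import CycleIterator

train_iterator = CycleIterator(range(3))  # stand-in for a dataloader yielding 3 batches per epoch
initial_iter = 4                          # made-up count of batches seen before the interruption

for _ in range(initial_iter):             # replay and discard the already-seen batches
    next(train_iterator)

print(train_iterator.epoch)               # 1 -> resuming lands partway through the second epoch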
@@ -163,7 +163,7 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
        fabric.barrier()
        fabric.print(
            "Resuming data loader finished."
-            f" Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration {initial_iter}."
+            f" Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration {initial_iter}, epoch {train_iterator.epoch}."
        )

    running_loss = RunningMean(window=gradient_accumulation_steps, sync_on_compute=False).to(fabric.device)
@@ -212,6 +212,7 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
                "loss": loss,
                "iter": state["iter_num"],
                "step": state["step_count"],
+                "epoch": train_iterator.epoch,
                "iter_time": t1 - iter_t0,
                "remaining_time": (
                    (t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
@@ -222,7 +223,7 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
            }

            fabric.print(
-                f"iter {metrics['iter']} step {metrics['step']}: loss {metrics['loss']:.4f}, iter time:"
+                f"iter {metrics['iter']} | step {metrics['step']}: loss {metrics['loss']:.4f}, iter time:"
                f" {metrics['iter_time'] * 1000:.2f} ms,{' (optimizer.step)' if not is_accumulating else ''}"
                f" remaining time: {metrics['remaining_time'] / 3600 / 24:.2f} days"
            )
10 changes: 6 additions & 4 deletions tests/test_full.py
@@ -64,7 +64,7 @@ def test_pretrain_tiny_llama(tmp_path, monkeypatch):
    module.log_step_interval = 1
    module.log_iter_interval = 1
    module.eval_iters = 2
-    module.max_iters = 3
+    module.max_tokens = 8
    module.devices = 1
    module.global_batch_size = 1
    module.micro_batch_size = 1
@@ -90,9 +90,11 @@ def test_pretrain_tiny_llama(tmp_path, monkeypatch):
    with redirect_stdout(stdout):
        module.setup()

-    assert {p.name for p in tmp_path.glob("*.pth")} == {"step-00000001.pth", "step-00000002.pth", "step-00000003.pth"}
+    assert {p.name for p in tmp_path.glob("*.pth")} == {
+        "step-00000001.pth", "step-00000002.pth", "step-00000003.pth", "step-00000004.pth"
+    }

    logs = stdout.getvalue()
-    assert logs.count("optimizer.step") == module.max_iters
-    assert logs.count("val loss") == module.max_iters
+    assert logs.count("optimizer.step") == 4
+    assert logs.count("val loss") == 4
    assert "Total parameters: 1,888" in logs
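A note on why the expected counts move from 3 to 4: the pretrain script turns the token budget into an iteration budget via max_iters = max_tokens_per_device // (micro_batch_size * block_size) (see the tinyllama.py hunk above), and with a single device and global_batch_size == micro_batch_size == 1 every iteration is an optimizer step. For max_tokens = 8 to yield the 4 steps asserted here, the tiny test model's block size would have to be 2; that value is inferred from the numbers, not shown in this diff.

# Back-of-the-envelope check; block_size = 2 is an inference, not something visible in this hunk.
max_tokens = 8
micro_batch_size = 1
block_size = 2

tokens_per_iter = micro_batch_size * block_size  # 2 tokens consumed per iteration
max_iters = max_tokens // tokens_per_iter        # 8 // 2 = 4 iterations
print(max_iters)  # 4 -> four "optimizer.step" logs, four "val loss" logs, four step checkpoints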
19 changes: 19 additions & 0 deletions tests/test_utils.py
@@ -165,3 +165,22 @@ def test_num_parameters_bitsandbytes(mode):
    with fabric.init_module(empty_init=True):
        model = GPT.from_name("pythia-14m")
    assert num_parameters(model) == 14067712


def test_cycle_iterator():
    from lit_gpt.utils import CycleIterator

    iterator = CycleIterator([])
    with pytest.raises(StopIteration):
        next(iterator)

    iterator = CycleIterator(range(3))
    assert iterator.epoch == 0
    assert next(iterator) == 0
    assert iterator.epoch == 0
    assert next(iterator) == 1
    assert iterator.epoch == 0
    assert next(iterator) == 2
    assert iterator.epoch == 0
    assert next(iterator) == 0
    assert iterator.epoch == 1
