Automated linting
carmocca committed Nov 14, 2023
1 parent 247d7d4 commit fbfc729
Showing 3 changed files with 33 additions and 33 deletions.
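The commit message says only "Automated linting", so the tooling that produced the changes below is not visible here. As an illustration (not part of the commit), a minimal sketch of how such a pass might be scripted; the choice of ruff and black, and the helper itself, are assumptions:

# Hypothetical helper, not from this repository: run an automated lint/format
# pass over the files this commit touches. ruff and black are assumed tools;
# the commit does not say which formatter produced the changes below.
import subprocess

CHANGED_FILES = [
    "pretrain/tinyllama.py",
    "tests/test_full.py",
    "tests/test_packed_dataset.py",
]

def run_lint(paths):
    # Auto-fix simple lint findings (e.g. f-strings without placeholders), then reformat.
    subprocess.run(["ruff", "check", "--fix", *paths], check=False)
    subprocess.run(["black", *paths], check=True)

if __name__ == "__main__":
    run_lint(CHANGED_FILES)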
41 changes: 21 additions & 20 deletions pretrain/tinyllama.py
@@ -114,7 +114,9 @@ def main(fabric, resume):

model = torch.compile(model)
model = fabric.setup(model)
-optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
+optimizer = torch.optim.AdamW(
+    model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
+)
optimizer = fabric.setup_optimizers(optimizer)

state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}
@@ -164,10 +166,10 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
curr_iter = -1
fabric.barrier()
fabric.print(
f"Resuming data loader finished."
"Resuming data loader finished."
f"Took {time.perf_counter() - total_t0:.1f} seconds to reach iteration {initial_iter}."
)

if state["iter_num"] >= max_iters:
break
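In the hunk above, the first string passed to fabric.print lost its f-prefix: it contains no placeholders, so the prefix did nothing, and linters flag this (pyflakes/ruff rule F541, named here as an assumption since the commit does not say which linter was used). A small illustration, not part of the diff:

# Illustration of the f-string rule behind the change above (made-up strings).
message = "Resuming data loader finished."  # no {placeholders}: a plain string is enough
detail = f"Took {1.23:.1f} seconds."        # interpolation present: the f-string is justified
print(message, detail)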

@@ -179,8 +181,8 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
state["iter_num"] += 1
iter_t0 = time.perf_counter()

-input_ids = train_data[:, 0:model.config.block_size].contiguous().long()
-targets = train_data[:, 1:(model.config.block_size + 1)].contiguous().long()
+input_ids = train_data[:, 0 : model.config.block_size].contiguous().long()
+targets = train_data[:, 1 : (model.config.block_size + 1)].contiguous().long()

is_accumulating = state["iter_num"] % gradient_accumulation_steps != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
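The two slice rewrites above apply PEP 8's slice-spacing rule: when a slice bound is a more complex expression, the colon is treated as the lowest-priority operator and gets equal space on both sides. A made-up illustration, not part of the diff:

# Made-up example of the slice spacing applied above.
data = list(range(10))
block_size = 4
head = data[0:block_size]           # simple bounds: no spaces around the colon
shifted = data[1 : block_size + 1]  # expression bound: equal spacing on both sides
assert shifted == [1, 2, 3, 4]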
@@ -193,23 +195,25 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
optimizer.step()
optimizer.zero_grad()
state["step_count"] += 1

if state["iter_num"] % log_iter_interval == 0:
loss = loss.item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
-time=(t1 - total_t0),
+time=(t1 - total_t0),
flops=(measured_flops * log_iter_interval),
batches=state["iter_num"],
samples=(state["iter_num"] * micro_batch_size),
lengths=(state["iter_num"] * micro_batch_size * model.config.block_size),
)
metrics = {
"loss": loss,
"iter": state['iter_num'],
"step": state['step_count'],
"iter_time": (t1 - iter_t0),
"remaining_time": (t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']),
"iter": state["iter_num"],
"step": state["step_count"],
"iter_time": t1 - iter_t0,
"remaining_time": (
(t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
),
"tokens": state["iter_num"] * micro_batch_size * model.config.block_size,
"total_tokens": state["iter_num"] * micro_batch_size * model.config.block_size * fabric.world_size,
}
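The reformatted "remaining_time" entry above is a linear extrapolation: average seconds per iteration so far, multiplied by the iterations still to run. A worked example with made-up numbers, not part of the diff:

# Made-up numbers, only to illustrate the remaining_time formula shown above.
elapsed = 120.0   # t1 - total_t0, seconds since training started
initial_iter = 0  # assume a fresh run (no resume) for simplicity
iter_num = 400    # state["iter_num"]
max_iters = 1_000
remaining_time = elapsed / (iter_num - initial_iter) * (max_iters - iter_num)
print(remaining_time)  # 180.0 seconds: 0.3 s/iter * 600 remaining iterations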
@@ -231,10 +235,7 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
td = time.perf_counter() - t0

fabric.print(f"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms")
-metrics = {
-    "val_loss": val_loss,
-    "val_ppl": math.exp(val_loss),
-}
+metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=state["iter_num"])
fabric.barrier()

@@ -253,8 +254,8 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max
for k, val_data in enumerate(val_dataloader):
if k >= max_iters:
break
-input_ids = val_data[:, 0:model.config.block_size].contiguous().long()
-targets = val_data[:, 1:(model.config.block_size + 1)].contiguous().long()
+input_ids = val_data[:, 0 : model.config.block_size].contiguous().long()
+targets = val_data[:, 1 : (model.config.block_size + 1)].contiguous().long()
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
losses[k] = loss
@@ -273,13 +274,13 @@ def create_dataloaders(batch_size: int, block_size: int) -> Tuple[DataLoader, Da
train_datasets = [
StreamingDataset(
input_dir="data/slimpajama/train",
-item_loader=TokensLoader(block_size=effective_block_size),
+item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
drop_last=True,
),
StreamingDataset(
input_dir="data/starcoder",
-item_loader=TokensLoader(block_size=effective_block_size),
+item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
drop_last=True,
),
Expand All @@ -292,7 +293,7 @@ def create_dataloaders(batch_size: int, block_size: int) -> Tuple[DataLoader, Da

val_dataset = StreamingDataset(
input_dir="data/slimpajama/val",
-item_loader=TokensLoader(block_size=effective_block_size),
+item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
6 changes: 1 addition & 5 deletions tests/test_full.py
@@ -90,11 +90,7 @@ def test_pretrain_tiny_llama(tmp_path, fake_checkpoint_dir, monkeypatch):
with redirect_stdout(stdout):
module.setup()

-assert {p.name for p in tmp_path.glob("*.pth")} == {
-    "step-00000001.pth",
-    "step-00000002.pth",
-    "step-00000003.pth",
-}
+assert {p.name for p in tmp_path.glob("*.pth")} == {"step-00000001.pth", "step-00000002.pth", "step-00000003.pth"}

logs = stdout.getvalue()
assert logs.count("optimizer.step") == module.max_iters
19 changes: 11 additions & 8 deletions tests/test_packed_dataset.py
@@ -190,14 +190,17 @@ def test_sharded_packed_dataset(monkeypatch):
assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"]


@pytest.mark.parametrize("weights, expected", [
([1], [1]),
([2], [1]),
([2, 0.5], [0.8, 0.2]),
([1, 1, 1], [1 / 3, 1 / 3, 1 / 3]),
([0.3, 0, 0], [1.0, 0, 0]),
(None, [0.5, 0.5]),
])
@pytest.mark.parametrize(
("weights", "expected"),
[
([1], [1]),
([2], [1]),
([2, 0.5], [0.8, 0.2]),
([1, 1, 1], [1 / 3, 1 / 3, 1 / 3]),
([0.3, 0, 0], [1.0, 0, 0]),
(None, [0.5, 0.5]),
],
)
def test_combined_dataset_normalizes_weights(weights, expected, monkeypatch):
from lit_gpt.packed_dataset import CombinedDataset
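The parametrize cases above pin down the expected behaviour: weights are normalized to sum to 1, and None falls back to equal weighting across datasets. A minimal sketch of that normalization, inferred from the expected values rather than taken from the actual lit_gpt.packed_dataset implementation:

# Sketch inferred from the test expectations above; not the CombinedDataset source.
from typing import List, Optional, Sequence

def normalize_weights(weights: Optional[Sequence[float]], num_datasets: int) -> List[float]:
    if weights is None:
        # No weights given: every dataset gets an equal share.
        return [1 / num_datasets] * num_datasets
    total = sum(weights)
    return [w / total for w in weights]

assert normalize_weights([2, 0.5], 2) == [0.8, 0.2]
assert normalize_weights(None, 2) == [0.5, 0.5]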

