Override load_state_dict for the GPT LightningModule (Lightning-AI#806)
carmocca authored Dec 18, 2023
1 parent 228c4c3 commit afca6f5
Showing 2 changed files with 40 additions and 4 deletions.
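Note on the change: LightningGPTModule stores the real model in a `self.module` attribute, so the default `LightningModule.state_dict()` prefixes every key with "module.". The previous override stripped that prefix by slicing each key; this commit instead delegates both saving and loading to the wrapped GPT. A minimal sketch of the prefixing behavior, using plain torch.nn stand-ins rather than anything from this repository:

    import torch.nn as nn

    # `inner` plays the role of GPT; `wrapper` plays the LightningModule.
    inner = nn.Linear(4, 4)
    wrapper = nn.Module()
    wrapper.module = inner  # mirrors `self.module = GPT(config)` in configure_model

    print(list(inner.state_dict()))    # ['weight', 'bias']
    print(list(wrapper.state_dict()))  # ['module.weight', 'module.bias']

    # Delegating to the wrapped module keeps checkpoint keys identical to the
    # raw model's, so no "drop the first 7 characters" conversion is needed.
    assert set(wrapper.module.state_dict()) == set(inner.state_dict())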
19 changes: 15 additions & 4 deletions pretrain/openwebtext_trainer.py
@@ -2,7 +2,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Mapping, Optional
 
 import lightning as L
 import numpy as np
@@ -60,11 +60,17 @@ def configure_model(self) -> None:
         self.module.apply(self.module._init_weights)
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
+        if self.module is None:
+            raise RuntimeError("You forgot to call `model.configure_model()`")
+
         return torch.optim.AdamW(
             self.module.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
         )
 
     def on_fit_start(self) -> None:
+        if self.module is None:
+            raise RuntimeError("You forgot to call `model.configure_model()`")
+
         trainer = self.trainer
         with torch.device("meta"):
             meta_model = GPT(self.module.config)
@@ -102,9 +108,14 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
         self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
 
     def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
-        state_dict = super().state_dict(*args, **kwargs)
-        # drop "module."
-        return {k[7:]: v for k, v in state_dict.items()}
+        if self.module is None:
+            raise RuntimeError("You forgot to call `model.configure_model()`")
+        return self.module.state_dict()
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs):
+        if self.module is None:
+            raise RuntimeError("You forgot to call `model.configure_model()`")
+        return self.module.load_state_dict(state_dict, *args, **kwargs)
 
 
 def main(devices: int = 1, precision: Optional[str] = None) -> None:
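With both overrides in place, a checkpoint written by the Trainer should load straight into a standalone GPT with no key renaming. A hedged sketch of that round trip; the checkpoint path here is hypothetical, and "state_dict" is the key under which Lightning checkpoints normally store the module weights:

    import torch

    from lit_gpt.config import Config
    from lit_gpt.model import GPT

    checkpoint = torch.load("out/openwebtext/last.ckpt")  # hypothetical path

    model = GPT(Config.from_name("pythia-14m"))
    # Keys now match the raw GPT's, so strict loading needs no conversion step.
    model.load_state_dict(checkpoint["state_dict"], strict=True)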
25 changes: 25 additions & 0 deletions tests/test_openwebtext_trainer.py
@@ -0,0 +1,25 @@
+import pytest
+
+
+def test_lightningmodule_state_dict(tmp_path):
+    from lit_gpt.config import Config
+    from lit_gpt.model import GPT
+    from pretrain.openwebtext_trainer import LightningGPTModule
+
+    config = Config.from_name("pythia-14m")
+    model = GPT(config)
+    lm = LightningGPTModule(config)
+
+    # forgot configure_model
+    with pytest.raises(RuntimeError, match="forgot"):
+        lm.state_dict()
+    with pytest.raises(RuntimeError, match="forgot"):
+        lm.load_state_dict({})
+
+    lm.configure_model()
+
+    lm_state_dict = lm.state_dict()
+    # the state dict is the same so that the lightningmodule's checkpoints do not need to be converted
+    assert set(model.state_dict()) == set(lm_state_dict)
+    # the state dict can be loaded back
+    lm.load_state_dict(lm_state_dict, strict=True)
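Note that the equality assertion compares only the key sets (parameter names), not tensor values; matching names are exactly what a checkpoint conversion step would otherwise have to fix, which is why none is needed anymore.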
