Raise warning when loading a large model on a CPU device (Lightning-A…
rasbt authored Jun 26, 2024
1 parent b581076 commit 8ff04a9
Showing 8 changed files with 57 additions and 2 deletions.
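All eight files follow one pattern: a new check_file_size_on_cpu_and_warn helper is added to litgpt/utils.py, each loading entry point imports it and calls it on the checkpoint path right before the weights are loaded, and new tests cover the below-limit, above-limit-on-CPU, and above-limit-on-GPU cases.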
2 changes: 2 additions & 0 deletions litgpt/api.py
@@ -16,6 +16,7 @@
 from litgpt.chat.base import generate as stream_generate_fn
 from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -171,6 +172,7 @@ def load(

         if checkpoint_dir is not None:
             checkpoint_path = checkpoint_dir / "lit_model.pth"
+            check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
             load_checkpoint(fabric, model, checkpoint_path)
         return cls(
             model=model, tokenizer=tokenizer, devices=devices,
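With this in place, loading a large checkpoint on a CPU-only machine surfaces the warning through the Python API as well. A minimal sketch, assuming the load classmethod shown above is exposed as litgpt.LLM and using a hypothetical model name to stand in for any checkpoint whose lit_model.pth exceeds the limit:

    from litgpt import LLM

    # Hypothetical example model; any checkpoint over ~4.2 GB triggers the warning.
    # On a CPU-only machine this now emits something like:
    #   UserWarning: The file size of .../lit_model.pth is over 4.2 GB. ...
    llm = LLM.load("microsoft/phi-2")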
2 changes: 2 additions & 0 deletions litgpt/chat/base.py
@@ -18,6 +18,7 @@
 from litgpt.prompts import has_prompt_style, load_prompt_style
 from litgpt.scripts.merge_lora import merge_lora
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -221,6 +222,7 @@ def main(
     fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

     checkpoint_path = checkpoint_dir / "lit_model.pth"
+    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

     # Merge if this is a raw LoRA checkpoint
     if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
2 changes: 2 additions & 0 deletions litgpt/generate/adapter.py
@@ -17,6 +17,7 @@
 from litgpt.generate.base import generate
 from litgpt.prompts import has_prompt_style, load_prompt_style
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -96,6 +97,7 @@ def main(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

     checkpoint_path = checkpoint_dir / "lit_model.pth"
+    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

     tokenizer = Tokenizer(checkpoint_dir)
     prompt_style = (
2 changes: 2 additions & 0 deletions litgpt/generate/adapter_v2.py
@@ -17,6 +17,7 @@
 from litgpt.generate.base import generate
 from litgpt.prompts import has_prompt_style, load_prompt_style
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -96,6 +97,7 @@ def main(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

     checkpoint_path = checkpoint_dir / "lit_model.pth"
+    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

     tokenizer = Tokenizer(checkpoint_dir)
     prompt_style = (
2 changes: 2 additions & 0 deletions litgpt/generate/base.py
@@ -19,6 +19,7 @@
 from litgpt.tokenizer import Tokenizer
 from litgpt.prompts import has_prompt_style, load_prompt_style, PromptStyle
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -217,6 +218,7 @@ def main(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

     checkpoint_path = checkpoint_dir / "lit_model.pth"
+    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

     tokenizer = Tokenizer(checkpoint_dir)
     prompt_style = (
3 changes: 2 additions & 1 deletion litgpt/generate/full.py
@@ -16,6 +16,7 @@
 from litgpt.generate.base import generate
 from litgpt.prompts import has_prompt_style, load_prompt_style
 from litgpt.utils import (
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     extend_checkpoint_dir,
     get_default_supported_precision,
@@ -95,7 +96,7 @@ def main(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

     checkpoint_path = finetuned_path
-
+    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
     tokenizer = Tokenizer(checkpoint_dir)
     prompt_style = (
         load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
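Note that full.py is the one call site where the check runs on finetuned_path rather than checkpoint_dir / "lit_model.pth", since the finetuned weights file is what gets loaded here.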
17 changes: 17 additions & 0 deletions litgpt/utils.py
@@ -11,6 +11,7 @@
 from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
+import warnings

 import lightning as L
 import torch
@@ -544,3 +545,19 @@ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
         not checkpoint_dir.is_absolute() and
         new_checkpoint_dir.exists())
     return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
+
+
+def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660):
+    """
+    Checks the file size and raises a warning if it exceeds the size_limit.
+    The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660
+    """
+    size = 0.0
+    if os.path.exists(checkpoint_path):
+        size = os.path.getsize(checkpoint_path)
+        if size > size_limit and str(device) == "cpu":
+            warnings.warn(
+                f"The file size of {checkpoint_path} is over {size_limit/1024/1024/1024:.1f} GB. Using a model "
+                "with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU."
+            )
+    return size
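A minimal usage sketch of the new helper (the checkpoint path below is hypothetical; the function returns the size in bytes whether or not the warning fires, and 0.0 if the file does not exist):

    import torch
    from litgpt.utils import check_file_size_on_cpu_and_warn

    # Hypothetical path; substitute any downloaded lit_model.pth
    ckpt = "checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0/lit_model.pth"

    # Warns only if the file exceeds ~4.2 GB and str(device) == "cpu"
    size = check_file_size_on_cpu_and_warn(ckpt, torch.device("cpu"))
    print(f"Checkpoint size: {size / 1024**3:.2f} GB")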
29 changes: 28 additions & 1 deletion tests/test_utils.py
@@ -5,7 +5,7 @@
 from contextlib import redirect_stderr
 from io import StringIO
 from pathlib import Path
-from tempfile import TemporaryDirectory
+from tempfile import TemporaryDirectory, NamedTemporaryFile
 from unittest import mock

 import pytest
@@ -25,6 +25,7 @@
     CLI,
     CycleIterator,
     capture_hparams,
+    check_file_size_on_cpu_and_warn,
     check_valid_checkpoint_dir,
     choose_logger,
     chunked_cross_entropy,
@@ -426,3 +427,29 @@ def test_extend_checkpoint_dir(input_path, expected):
 ])
 def test_extend_checkpoint_dir_dont_exist(input_path, expected):
     assert extend_checkpoint_dir(input_path) == expected
+
+
+def test_file_size_below_limit_on_cpu():
+    # Test file size below limit on CPU
+    with NamedTemporaryFile() as temp_file:
+        with mock.patch("os.path.getsize", return_value=4_000_000_000):
+            size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")
+            assert size == 4_000_000_000
+
+
+def test_file_size_above_limit_on_cpu():
+    # Test file size above limit on CPU
+    with NamedTemporaryFile() as temp_file:
+        with mock.patch("os.path.getsize", return_value=4_600_000_000):
+            with pytest.warns(UserWarning) as record:
+                size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")
+            assert size == 4_600_000_000
+            assert "over 4.2 GB" in str(record[0].message)


+def test_file_size_above_limit_on_gpu():
+    # Test file size above limit on GPU should not warn
+    with NamedTemporaryFile() as temp_file:
+        with mock.patch("os.path.getsize", return_value=4_600_000_000):
+            size = check_file_size_on_cpu_and_warn(temp_file.name, "gpu")
+            assert size == 4_600_000_000
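The tests pass the device as a plain string because the helper only compares str(device) against "cpu", so strings and torch.device objects behave identically. An illustrative check of that equivalence (not part of the commit):

    import torch

    assert str(torch.device("cpu")) == "cpu"         # matches: warning can fire
    assert str(torch.device("cuda", 0)) == "cuda:0"  # no match: never warns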
