diff --git a/README.md b/README.md index 33f86a1a52..db8102a1e0 100644 --- a/README.md +++ b/README.md @@ -201,20 +201,17 @@ Finetuning is the process of taking a pretrained AI model and further training i # 0) setup your dataset curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json -# 1) Download a pretrained model -litgpt download microsoft/phi-2 - -# 2) Finetune the model +# 1) Download and finetune the model litgpt finetune microsoft/phi-2 \ --data JSON \ --data.json_path my_custom_dataset.json \ --data.val_split_fraction 0.1 \ --out_dir out/custom-model -# 3) Test the model +# 2) Test the model litgpt chat out/custom-model/final -# 4) Deploy the model +# 3) Deploy the model litgpt serve out/custom-model/final ``` @@ -238,7 +235,6 @@ Deploy a pretrained or finetune LLM to use it in real-world applications. Deploy ```bash # deploy an out-of-the-box LLM -litgpt download microsoft/phi-2 litgpt serve microsoft/phi-2 # deploy your own trained model @@ -306,11 +302,10 @@ litgpt chat microsoft/phi-2   ```bash -# 1) Download the LLM +# 1) List all supported LLMs litgpt download list -litgpt download microsoft/phi-2 -# 2) Test the model +# 2) Download and use the model litgpt chat microsoft/phi-2 >> Prompt: What do Llamas eat? @@ -393,10 +388,7 @@ mkdir -p custom_texts curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt -# 1) Download a pretrained model -litgpt download EleutherAI/pythia-160m - -# 2) Continue pretraining the model +# 1) Download and continue pretraining a model litgpt pretrain EleutherAI/pythia-160m \ --tokenizer_dir EleutherAI/pythia-160m \ --initial_checkpoint_dir EleutherAI/pythia-160m \ @@ -405,7 +397,7 @@ litgpt pretrain EleutherAI/pythia-160m \ --train.max_tokens 10_000_000 \ --out_dir out/custom-model -# 3) Test the model +# 2) Test the model litgpt chat out/custom-model/final ``` diff --git a/litgpt/api.py b/litgpt/api.py index 68f02c10bc..d4b4752154 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -1,7 +1,6 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. # # This file implements the LitGPT Python API -import os from pathlib import Path from typing import Any, List, Literal, Optional, Union @@ -16,8 +15,8 @@ from litgpt.chat.base import generate as stream_generate_fn from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle from litgpt.utils import ( + auto_download_checkpoint, check_file_size_on_cpu_and_warn, - check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, @@ -109,17 +108,7 @@ def load( allowed_init = {"pretrained", "random"} if init == "pretrained": - from litgpt.scripts.download import download_from_hub # Moved here due to the circular import issue in LitGPT that we need to solve some time - - checkpoint_dir = extend_checkpoint_dir(Path(model)) - try: - check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True) - except FileNotFoundError: - if not access_token: - access_token = os.getenv("HF_TOKEN") - download_from_hub(repo_id=model, access_token=access_token) - - checkpoint_dir = Path("checkpoints") / model + checkpoint_dir = auto_download_checkpoint(model_name=model, access_token=access_token) config = Config.from_file(checkpoint_dir / "model_config.yaml") elif init == "random": diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py index d5b2f047fb..3af15018d4 100644 --- a/litgpt/chat/base.py +++ b/litgpt/chat/base.py @@ -18,9 +18,9 @@ from litgpt.prompts import has_prompt_style, load_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.utils import ( + auto_download_checkpoint, check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, - extend_checkpoint_dir, get_default_supported_precision, load_checkpoint ) @@ -176,11 +176,13 @@ def main( precision: Optional[str] = None, compile: bool = False, multiline: bool = False, + access_token: Optional[str] = None, ) -> None: """Chat with a model. Args: - checkpoint_dir: The checkpoint directory to load. + checkpoint_dir: A local path to a directory containing the model weights or a valid model name. + You can get a list of valid model names via the `litgpt download list` command line argument. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -205,8 +207,8 @@ def main( precision: Indicates the Fabric precision setting to use. compile: Whether to use compilation to speed up token generation. Will increase startup time. multiline: Whether to support multiline input prompts. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) @@ -229,7 +231,7 @@ def main( print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.") merge_lora(checkpoint_dir) - check_valid_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) config = Config.from_file(checkpoint_dir / "model_config.yaml") with fabric.init_module(empty_init=True): diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py index fa7341bf0b..ea5109a20b 100644 --- a/litgpt/deploy/serve.py +++ b/litgpt/deploy/serve.py @@ -16,7 +16,7 @@ from litgpt.chat.base import generate as stream_generate from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle from litgpt.utils import ( - extend_checkpoint_dir, + auto_download_checkpoint, get_default_supported_precision, load_checkpoint ) @@ -173,7 +173,8 @@ def run_server( devices: int = 1, accelerator: str = "auto", port: int = 8000, - stream: bool = False + stream: bool = False, + access_token: Optional[str] = None, ) -> None: """Serve a LitGPT model using LitServe. @@ -207,12 +208,11 @@ def run_server( The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU. port: The network port number on which the model is configured to be served. stream: Whether to stream the responses. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) - check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth") - if not stream: server = LitServer( SimpleLitAPI( diff --git a/litgpt/eval/evaluate.py b/litgpt/eval/evaluate.py index 6de98212d3..31ddf015f5 100644 --- a/litgpt/eval/evaluate.py +++ b/litgpt/eval/evaluate.py @@ -8,7 +8,7 @@ import torch from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint -from litgpt.utils import copy_config_files, extend_checkpoint_dir +from litgpt.utils import copy_config_files, auto_download_checkpoint def prepare_results(results, save_filepath, print_results=True): @@ -37,6 +37,7 @@ def convert_and_evaluate( limit: Optional[float] = None, seed: int = 1234, save_filepath: Optional[Path] = None, + access_token: Optional[str] = None, ) -> None: """Evaluate a model with the LM Evaluation Harness. @@ -55,6 +56,7 @@ def convert_and_evaluate( seed: Random seed. save_filepath: The file where the results will be saved. Saves to `out_dir/results.json` by default. + access_token: Optional API token to access models with restrictions. """ if tasks is None: from lm_eval.tasks import TaskManager @@ -68,7 +70,7 @@ def convert_and_evaluate( ) return - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) if not (isinstance(batch_size, int) and batch_size > 0) and not (isinstance(batch_size, str) and batch_size.startswith("auto")): diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index f13f66a72d..322b353616 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -24,12 +24,12 @@ from litgpt.prompts import save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + auto_download_checkpoint, CycleIterator, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, - extend_checkpoint_dir, get_default_supported_precision, init_out_dir, instantiate_torch_optimizer, @@ -62,6 +62,7 @@ def setup( optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, + access_token: Optional[str] = None, ) -> None: """Finetune a model using the Adapter method. @@ -79,8 +80,9 @@ def setup( optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index fc355af2d1..c0bf1c22d0 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -24,12 +24,12 @@ from litgpt.prompts import save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + auto_download_checkpoint, CycleIterator, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, - extend_checkpoint_dir, get_default_supported_precision, init_out_dir, instantiate_torch_optimizer, @@ -62,6 +62,7 @@ def setup( optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, + access_token: Optional[str] = None, ) -> None: """Finetune a model using the Adapter V2 method. @@ -79,8 +80,9 @@ def setup( optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 038848707a..05c8c17d20 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -20,12 +20,12 @@ from litgpt.prompts import save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + auto_download_checkpoint, CycleIterator, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, - extend_checkpoint_dir, find_resume_path, get_default_supported_precision, load_checkpoint, @@ -58,6 +58,7 @@ def setup( optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, + access_token: Optional[str] = None, ) -> None: """Finetune a model. @@ -77,8 +78,9 @@ def setup( optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 276f58a937..9cc153ab9a 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -25,12 +25,12 @@ from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + auto_download_checkpoint, CycleIterator, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, - extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, init_out_dir, @@ -72,6 +72,7 @@ def setup( optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, + access_token: Optional[str] = None, ) -> None: """Finetune a model using the LoRA method. @@ -98,8 +99,9 @@ def setup( optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. + access_token: Optional API token to access models with restrictions. """ - checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) diff --git a/litgpt/utils.py b/litgpt/utils.py index db4e54d9b2..69cab686c5 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -26,6 +26,7 @@ from torch.serialization import normalize_storage_type from typing_extensions import Self + if TYPE_CHECKING: from litgpt import GPT, Config @@ -561,3 +562,22 @@ def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_71 "with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU." ) return size + + +def auto_download_checkpoint(model_name, access_token=None): + from litgpt.scripts.download import download_from_hub # moved here due to circular import issue + + checkpoint_dir = extend_checkpoint_dir(Path(model_name)) + try: + check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True) + except FileNotFoundError as e: + if access_token is None: + access_token = os.getenv("HF_TOKEN") + + if checkpoint_dir.parts[0] != "checkpoints" and not checkpoint_dir.is_absolute(): + download_from_hub(repo_id=str(model_name), access_token=access_token) + checkpoint_dir = Path("checkpoints") / checkpoint_dir + else: + raise e + + return checkpoint_dir diff --git a/tests/test_chat.py b/tests/test_chat.py index 311969a4d0..03a034600e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,7 +2,6 @@ import os import re import subprocess -import sys from contextlib import redirect_stderr, redirect_stdout from io import StringIO from itertools import repeat diff --git a/tests/test_config_hub.py b/tests/test_config_hub.py index 214359dffc..718a671616 100644 --- a/tests/test_config_hub.py +++ b/tests/test_config_hub.py @@ -54,7 +54,11 @@ def test_config_help(script_file, config_file, monkeypatch): monkeypatch.setattr(module, "Config", Mock(return_value=Config.from_name("pythia-14m"))) monkeypatch.setattr(module, "check_valid_checkpoint_dir", Mock(), raising=False) - with mock.patch("sys.argv", [script_file.name, "--config", str(config_file), "--devices", "1"]): - CLI(module.setup) - - module.main.assert_called_once() + try: + with mock.patch("sys.argv", [script_file.name, "--config", str(config_file), "--devices", "1"]): + CLI(module.setup) + module.main.assert_called_once() + except FileNotFoundError: + pass + # FileNotFound occurs here because we have not downloaded the model weights referenced in the config files + # which is ok because here we just want to validate the config file itself.