Skip to content

Commit

Permalink
Check checkpoint_dir and add checkpoints to path (Lightning-AI#1454)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <[email protected]>
  • Loading branch information
rasbt and awaelchli authored Jun 4, 2024
1 parent 798d725 commit e567dbe
Show file tree
Hide file tree
Showing 47 changed files with 315 additions and 136 deletions.
22 changes: 11 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ After installing LitGPT, select the model and action you want to take on that mo
```bash
# litgpt [action] [model]
litgpt download meta-llama/Meta-Llama-3-8B-Instruct
litgpt chat checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
litgpt finetune checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
litgpt pretrain checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
litgpt serve checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
litgpt chat meta-llama/Meta-Llama-3-8B-Instruct
litgpt finetune meta-llama/Meta-Llama-3-8B-Instruct
litgpt pretrain meta-llama/Meta-Llama-3-8B-Instruct
litgpt serve meta-llama/Meta-Llama-3-8B-Instruct
```

&nbsp;
Expand All @@ -162,7 +162,7 @@ litgpt download list
litgpt download microsoft/phi-2

# 3) Chat with the model
litgpt chat checkpoints/microsoft/phi-2
litgpt chat microsoft/phi-2

>> Prompt: What do Llamas eat?
```
Expand All @@ -188,7 +188,7 @@ litgpt download microsoft/phi-2
# 2) Finetune the model
curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json

litgpt finetune checkpoints/microsoft/phi-2 \
litgpt finetune microsoft/phi-2 \
--data JSON \
--data.json_path my_custom_dataset.json \
--data.val_split_fraction 0.1 \
Expand Down Expand Up @@ -220,7 +220,7 @@ litgpt download EleutherAI/pythia-160m \

# 2) Pretrain the model
litgpt pretrain EleutherAI/pythia-160m \
--tokenizer_dir checkpoints/EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--data TextFiles \
--data.train_data_path "custom_texts/" \
--train.max_tokens 10_000_000 \
Expand Down Expand Up @@ -252,8 +252,8 @@ litgpt download EleutherAI/pythia-160m

# 2) Continue pretraining the model
litgpt pretrain EleutherAI/pythia-160m \
--tokenizer_dir checkpoints/EleutherAI/pythia-160m \
--initial_checkpoint_dir checkpoints/EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--initial_checkpoint_dir EleutherAI/pythia-160m \
--data TextFiles \
--data.train_data_path "custom_texts/" \
--train.max_tokens 10_000_000 \
Expand All @@ -276,11 +276,11 @@ Once you're ready to deploy a finetuned LLM, run this command:

```bash
# locate the checkpoint of your finetuned or pretrained model and call the `serve` command:
litgpt serve checkpoints/microsoft/phi-2
litgpt serve microsoft/phi-2

# Alternative: if you haven't finetuned, download any checkpoint to deploy it:
litgpt download microsoft/phi-2
litgpt serve checkpoints/microsoft/phi-2
litgpt serve microsoft/phi-2
```

Test the server in a separate terminal and integrate the model API into your AI product:
Expand Down
11 changes: 10 additions & 1 deletion litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Iterator, List, Literal, Optional, Tuple

import lightning as L
Expand All @@ -13,7 +14,12 @@
from litgpt.generate.base import next_token
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
)


@torch.inference_mode()
Expand Down Expand Up @@ -196,6 +202,9 @@ def main(
compile: Whether to use compilation to speed up token generation. Will increase startup time.
multiline: Whether to support multiline input prompts.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)

plugins = None
Expand Down
4 changes: 2 additions & 2 deletions litgpt/data/prepare_slimpajama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from litgpt import Tokenizer
from litgpt.data.prepare_starcoder import DataChunkRecipe
from litgpt.utils import CLI
from litgpt.utils import CLI, extend_checkpoint_dir


class SlimPajamaDataRecipe(DataChunkRecipe):
Expand Down Expand Up @@ -40,7 +40,7 @@ def prepare(
) -> None:
from litdata.processing.data_processor import DataProcessor

tokenizer = Tokenizer(tokenizer_path)
tokenizer_path = extend_checkpoint_dir(tokenizer_path)
data_recipe = SlimPajamaDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)
data_processor = DataProcessor(
input_dir=str(input_dir),
Expand Down
3 changes: 2 additions & 1 deletion litgpt/data/prepare_starcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lightning_utilities.core.imports import RequirementCache

from litgpt import Tokenizer
from litgpt.utils import CLI
from litgpt.utils import CLI, extend_checkpoint_dir

_LITDATA_AVAILABLE = RequirementCache("litdata")
if _LITDATA_AVAILABLE:
Expand Down Expand Up @@ -58,6 +58,7 @@ def prepare(
) -> None:
from litdata.processing.data_processor import DataProcessor

tokenizer_path = extend_checkpoint_dir(tokenizer_path)
tokenizer = Tokenizer(tokenizer_path)
data_recipe = StarcoderDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)
data_processor = DataProcessor(
Expand Down
10 changes: 9 additions & 1 deletion litgpt/deploy/serve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from pathlib import Path
from pprint import pprint
from typing import Dict, Any, Optional
from litgpt.utils import check_valid_checkpoint_dir

Expand All @@ -13,7 +14,11 @@
from litgpt.tokenizer import Tokenizer
from litgpt.generate.base import generate
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import load_checkpoint, get_default_supported_precision
from litgpt.utils import (
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
)


_LITSERVE_AVAILABLE = RequirementCache("litserve")
Expand Down Expand Up @@ -149,6 +154,9 @@ def run_server(
The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
port: The network port number on which the model is configured to be served.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

server = LitServer(
Expand Down
5 changes: 4 additions & 1 deletion litgpt/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import json
import os
from pathlib import Path
from pprint import pprint
from typing import Optional, Union
import torch

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint
from litgpt.utils import copy_config_files
from litgpt.utils import copy_config_files, extend_checkpoint_dir


def prepare_results(results, save_filepath, print_results=True):
Expand Down Expand Up @@ -54,6 +55,8 @@ def convert_and_evaluate(
save_filepath: The file where the results will be saved.
Saves to `out_dir/results.json` by default.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

from lm_eval import evaluator

Expand Down
2 changes: 2 additions & 0 deletions litgpt/finetune/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
Expand Down Expand Up @@ -75,6 +76,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
Expand Down
3 changes: 2 additions & 1 deletion litgpt/finetune/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
Expand Down Expand Up @@ -75,7 +76,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""

checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
Expand Down
2 changes: 2 additions & 0 deletions litgpt/finetune/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
init_out_dir,
Expand Down Expand Up @@ -73,6 +74,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
Expand Down
2 changes: 2 additions & 0 deletions litgpt/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
init_out_dir,
Expand Down Expand Up @@ -94,6 +95,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
Expand Down
11 changes: 10 additions & 1 deletion litgpt/generate/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional

import lightning as L
Expand All @@ -13,7 +14,12 @@
from litgpt.adapter import GPT, Config
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, lazy_load
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
lazy_load
)


def main(
Expand Down Expand Up @@ -63,6 +69,9 @@ def main(
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)

plugins = None
Expand Down
11 changes: 10 additions & 1 deletion litgpt/generate/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional

import lightning as L
Expand All @@ -13,7 +14,12 @@
from litgpt.adapter_v2 import GPT, Config
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, lazy_load
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
lazy_load
)


def main(
Expand Down Expand Up @@ -63,6 +69,9 @@ def main(
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)

plugins = None
Expand Down
11 changes: 10 additions & 1 deletion litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Any, Literal, Optional

import lightning as L
Expand All @@ -13,7 +14,12 @@

from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
)


def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -178,6 +184,9 @@ def main(
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)

plugins = None
Expand Down
11 changes: 10 additions & 1 deletion litgpt/generate/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional

import lightning as L
Expand All @@ -12,7 +13,12 @@
from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
)


def main(
Expand Down Expand Up @@ -62,6 +68,9 @@ def main(
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)

plugins = None
Expand Down
Loading

0 comments on commit e567dbe

Please sign in to comment.