Skip to content

Commit

Permalink
chat/base.py: extend checkpoint_dir before accessing it (Lightning-AI#1575)
Browse files Browse the repository at this point in the history

Co-authored-by: rasbt <[email protected]>
  • Loading branch information
Andrei-Aksionov and rasbt authored Jul 13, 2024
1 parent 868981b commit d4c87bd
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
load_checkpoint,
)


Expand Down Expand Up @@ -209,6 +209,7 @@ def main(
multiline: Whether to support multiline input prompts.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)
Expand All @@ -223,15 +224,15 @@ def main(

fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

# Merge if this is a raw LoRA checkpoint
checkpoint_path = checkpoint_dir / "lit_model.pth"
if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_dir)

checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

config = Config.from_file(checkpoint_dir / "model_config.yaml")

with fabric.init_module(empty_init=True):
Expand Down

0 comments on commit d4c87bd

Please sign in to comment.