Skip to content

Commit

Permalink
Restore capability to load alternative weights (Lightning-AI#1620)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Jul 23, 2024
1 parent 8ca83b8 commit 5ff6343
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions litgpt/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ def download_from_hub(
print("\n".join(sorted(options, key=lambda x: x.lower())))
return

if repo_id not in options:
print(f"Unsupported repo_id: {repo_id}. Please choose a valid repo_id from the following list:")
print("\n".join(sorted(options, key=lambda x: x.lower())))
if model_name is None and repo_id not in options:
print(f"Unsupported `repo_id`: {repo_id}."
"\nIf you are trying to download alternative "
"weights for a supported model, please specify the corresponding model via the `--model_name` option, "
"for example, `litgpt download NousResearch/Hermes-2-Pro-Llama-3-8B --model_name Llama-3-8B`."
"\nAlternatively, please choose a valid `repo_id` from the list of supported models, which can be obtained via "
"`litgpt download list`.")
return

from huggingface_hub import snapshot_download
Expand Down
2 changes: 1 addition & 1 deletion tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_download_model():
# Also test valid but unsupported repo IDs
command = ["litgpt", "download", "CohereForAI/aya-23-8B"]
output = run_command(command)
assert "Unsupported repo_id" in output
assert "Unsupported `repo_id`" in output


@pytest.mark.dependency()
Expand Down

0 comments on commit 5ff6343

Please sign in to comment.