forked from pytorch/torchchat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
download.py
96 lines (76 loc) · 2.92 KB
/
download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import urllib.request
from pathlib import Path
from typing import Optional, Sequence
from build.convert_hf_checkpoint import convert_hf_checkpoint
from config.model_config import (
ModelConfig,
ModelDistributionChannel,
resolve_model_config,
)
from requests.exceptions import HTTPError
def _download_hf_snapshot(
    model_config: ModelConfig, models_dir: Path, hf_token: Optional[str]
):
    """Download a model snapshot from HuggingFace and convert it for torchchat.

    The snapshot is written to ``models_dir / model_config.name`` (created if
    missing), skipping files that match ``*safetensors*``. The directory is
    then handed to ``convert_hf_checkpoint``.

    Args:
        model_config: Supplies the directory/model name (``name``) and the
            identifier passed to ``snapshot_download``
            (``distribution_path``).
        models_dir: Parent directory under which the model directory lives.
        hf_token: Token forwarded to ``snapshot_download``; may be ``None``.

    Raises:
        RuntimeError: If the download fails with HTTP status 401.
        HTTPError: Any other HTTP failure is re-raised unchanged.
    """
    model_dir = models_dir / model_config.name
    os.makedirs(model_dir, exist_ok=True)

    # Imported at call time rather than at module import time.
    from huggingface_hub import snapshot_download

    # Download and store the HF model artifacts.
    print(f"Downloading {model_config.name} from HuggingFace...")
    try:
        snapshot_download(
            model_config.distribution_path,
            local_dir=model_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
            ignore_patterns="*safetensors*",
        )
    except HTTPError as e:
        # The exception may carry no response object; treat that as "not 401".
        status_code = getattr(e.response, "status_code", None)
        if status_code != 401:
            raise

        # Best-effort cleanup of the directory created above. os.rmdir only
        # removes empty directories; if the directory already holds files we
        # leave it alone instead of letting an OSError hide the real problem.
        try:
            os.rmdir(model_dir)
        except OSError:
            pass

        raise RuntimeError(
            "Access denied. Run huggingface-cli login to authenticate."
        ) from e

    # Convert the model to the torchchat format.
    print(f"Converting {model_config.name} to torchchat format...")
    convert_hf_checkpoint(
        model_dir=model_dir, model_name=model_config.name, remove_bin_files=True
    )
def _download_direct(
    model_config: ModelConfig,
    urls: Sequence[str],
    models_dir: Path,
):
    """Download each URL into ``models_dir / model_config.name``.

    The model directory is created if it does not exist. Each file is saved
    under the last component of its URL path.

    Args:
        model_config: Supplies the model directory name (``name``).
        urls: URLs to fetch, downloaded one after another in order.
        models_dir: Parent directory under which the model directory lives.

    Raises:
        ValueError: If a URL's path has no final component to use as a file
            name (for example, the URL ends with ``/``).
    """
    from urllib.parse import urlsplit

    model_dir = models_dir / model_config.name
    os.makedirs(model_dir, exist_ok=True)

    for url in urls:
        # Derive the name from the URL *path* only, so a query string or
        # fragment ("file.bin?download=true") does not leak into the name.
        filename = urlsplit(url).path.split("/")[-1]
        if not filename:
            raise ValueError(f"Cannot determine a file name from URL {url!r}.")

        local_path = model_dir / filename
        print(f"Downloading {url}...")
        urllib.request.urlretrieve(url, str(local_path.absolute()))
def download_and_convert(
    model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
    """Resolve ``model`` and fetch it through its distribution channel.

    Args:
        model: Identifier passed to ``resolve_model_config``.
        models_dir: Directory under which the model is stored.
        hf_token: Token forwarded to the HuggingFace snapshot download; not
            used for direct downloads.

    Raises:
        RuntimeError: If the resolved config names a distribution channel
            this function does not handle.
    """
    config = resolve_model_config(model)

    # HuggingFace snapshots are downloaded and then converted.
    if config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
        _download_hf_snapshot(config, models_dir, hf_token)
        return

    # Direct downloads pass the distribution path as the URLs to fetch.
    if config.distribution_channel == ModelDistributionChannel.DirectDownload:
        _download_direct(config, config.distribution_path, models_dir)
        return

    raise RuntimeError(
        f"Unknown distribution channel {config.distribution_channel}."
    )
def is_model_downloaded(model: str, models_dir: Path) -> bool:
    """Return whether the model's directory exists and is not empty.

    Args:
        model: Identifier passed to ``resolve_model_config``.
        models_dir: Directory under which model directories are stored.

    Returns:
        ``True`` if ``models_dir / <resolved model name>`` is a directory
        containing at least one entry, ``False`` otherwise.
    """
    model_config = resolve_model_config(model)

    # Check if the model directory exists and is not empty.
    model_dir = models_dir / model_config.name
    # bool() makes the result a real True/False, matching the annotation,
    # rather than returning the directory listing itself.
    return os.path.isdir(model_dir) and bool(os.listdir(model_dir))
def main(args):
    """Download and convert the model described by ``args``.

    Reads ``args.model``, ``args.model_directory`` and ``args.hf_token`` and
    forwards them to ``download_and_convert``.
    """
    model = args.model
    models_dir = args.model_directory
    hf_token = args.hf_token
    download_and_convert(model, models_dir, hf_token)