download.py (forked from pytorch/torchchat)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import sys
import urllib.request
from pathlib import Path
from typing import Optional

from build.convert_hf_checkpoint import convert_hf_checkpoint
from config.model_config import (
    load_model_configs,
    ModelConfig,
    ModelDistributionChannel,
    resolve_model_config,
)


def _download_hf_snapshot(
    model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
):
    from huggingface_hub import snapshot_download
    from requests.exceptions import HTTPError

    # Download and store the HF model artifacts.
    print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
    try:
        snapshot_download(
            model_config.distribution_path,
            local_dir=artifact_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
            ignore_patterns="*safetensors*",
        )
    except HTTPError as e:
        if e.response.status_code == 401:  # Missing HuggingFace CLI login.
            print(
                "Access denied. Create a HuggingFace account and run 'pip3 install huggingface_hub' and 'huggingface-cli login' to authenticate.",
                file=sys.stderr,
            )
            exit(1)
        elif e.response.status_code == 403:  # No access to the specific model.
            # The error message includes a link to request access to the given
            # model. This prints nicely and does not include a traceback.
            print(str(e), file=sys.stderr)
            exit(1)
        else:
            raise e

    # Convert the model to the torchchat format.
    print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
    convert_hf_checkpoint(
        model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
    )
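
# Illustrative sketch (not part of the upstream module): how this helper would be
# invoked directly. The alias, target path, and token handling below are
# assumptions for illustration only; in normal use download_and_convert() resolves
# the config and picks the distribution channel.
#
#   cfg = resolve_model_config("llama3")  # hypothetical alias
#   _download_hf_snapshot(
#       cfg, Path("/tmp/llama3-download"), hf_token=os.environ.get("HF_TOKEN")
#   )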


def _download_direct(
    model_config: ModelConfig,
    artifact_dir: Path,
):
    for url in model_config.distribution_path:
        filename = url.split("/")[-1]
        local_path = artifact_dir / filename
        print(f"Downloading {url}...", file=sys.stderr)
        urllib.request.urlretrieve(url, str(local_path.absolute()))
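
# Note (illustrative, based on the loop above): for the DirectDownload channel,
# model_config.distribution_path is treated as an iterable of URLs, and each file
# is saved under artifact_dir using the URL's basename. For example, a hypothetical
# config listing .../model.pth and .../tokenizer.model would produce
# artifact_dir/model.pth and artifact_dir/tokenizer.model.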


def download_and_convert(
    model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
    model_config = resolve_model_config(model)
    model_dir = models_dir / model_config.name

    # Download into a temporary directory. We'll move to the final
    # location once the download and conversion is complete. This
    # allows recovery in the event that the download or conversion
    # fails unexpectedly.
    temp_dir = models_dir / "downloads" / model_config.name
    if os.path.isdir(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    try:
        if (
            model_config.distribution_channel
            == ModelDistributionChannel.HuggingFaceSnapshot
        ):
            _download_hf_snapshot(model_config, temp_dir, hf_token)
        elif (
            model_config.distribution_channel == ModelDistributionChannel.DirectDownload
        ):
            _download_direct(model_config, temp_dir)
        else:
            raise RuntimeError(
                f"Unknown distribution channel {model_config.distribution_channel}."
            )

        # Move from the temporary directory to the intended location,
        # overwriting if necessary.
        if os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
        shutil.move(temp_dir, model_dir)
    finally:
        if os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir)
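
# Example usage (hypothetical alias, directory, and token source, shown only for
# illustration; the real defaults come from the CLI argument parsing):
#
#   download_and_convert(
#       "llama3",                                     # model name or alias
#       Path.home() / ".torchchat" / "model-cache",   # models_dir (assumed path)
#       hf_token=os.environ.get("HF_TOKEN"),          # needed only for gated HF repos
#   )
#
# A failed download leaves nothing behind in models_dir/<name>: all work happens
# under models_dir/"downloads"/<name> and is moved into place only on success.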


def is_model_downloaded(model: str, models_dir: Path) -> bool:
    model_config = resolve_model_config(model)

    # Check if the model directory exists and is not empty.
    model_dir = models_dir / model_config.name
    return os.path.isdir(model_dir) and bool(os.listdir(model_dir))
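
# Example (hypothetical alias and directory): returns True only if the resolved
# model directory exists and contains at least one entry.
#
#   is_model_downloaded("llama3", Path.home() / ".torchchat" / "model-cache")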


# Subcommand to list available models.
def list_main(args) -> None:
    model_configs = load_model_configs()

    # Build the table in-memory so that we can align the text nicely.
    name_col = []
    aliases_col = []
    installed_col = []
    for name, config in model_configs.items():
        is_downloaded = is_model_downloaded(name, args.model_directory)

        name_col.append(name)
        aliases_col.append(", ".join(config.aliases))
        installed_col.append("Yes" if is_downloaded else "")

    cols = {"Model": name_col, "Aliases": aliases_col, "Downloaded": installed_col}

    # Find the length of the longest value in each column.
    col_widths = {
        key: max(*[len(s) for s in vals], len(key)) + 1 for (key, vals) in cols.items()
    }

    # Display header.
    print()
    print(*[val.ljust(width) for (val, width) in col_widths.items()])
    print(*["-" * width for width in col_widths.values()])

    for i in range(len(name_col)):
        row = [col[i] for col in cols.values()]
        print(*[val.ljust(width) for (val, width) in zip(row, col_widths.values())])
    print()
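
# Illustrative output (model names and aliases below are hypothetical; column
# widths adapt to the longest entry in each column):
#
#   Model          Aliases      Downloaded
#   -------------- ------------ -----------
#   example-model  example, ex  Yes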


# Subcommand to remove downloaded model artifacts.
def remove_main(args) -> None:
    # TODO It would be nice to have argparse validate this. However, we have
    # model as an optional named parameter for all subcommands, so we'd
    # probably need to move it to be registered per-command.
    if not args.model:
        print("Usage: torchchat.py remove <model-or-alias>")
        return

    model_config = resolve_model_config(args.model)
    model_dir = args.model_directory / model_config.name

    if not os.path.isdir(model_dir):
        print(f"Model {args.model} has no downloaded artifacts.")
        return

    print(f"Removing downloaded model artifacts for {args.model}...")
    shutil.rmtree(model_dir)
    print("Done.")


# Subcommand to print downloaded model artifacts directory.
# Asking for location will/should trigger download of model if not available.
def where_main(args) -> None:
    # TODO It would be nice to have argparse validate this. However, we have
    # model as an optional named parameter for all subcommands, so we'd
    # probably need to move it to be registered per-command.
    if not args.model:
        print("Usage: torchchat.py where <model-or-alias>")
        return

    model_config = resolve_model_config(args.model)
    model_dir = args.model_directory / model_config.name

    if not os.path.isdir(model_dir):
        raise RuntimeError(f"Model {args.model} has no downloaded artifacts.")

    print(str(os.path.abspath(model_dir)))
    exit(0)


# Subcommand to download model artifacts.
def download_main(args) -> None:
    download_and_convert(args.model, args.model_directory, args.hf_token)
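
# These *_main functions are wired up as subcommand handlers by the torchchat CLI;
# argument parsing lives outside this file. Hypothetical invocations, assuming the
# usual entry point and an example alias "llama3":
#
#   python3 torchchat.py download llama3
#   python3 torchchat.py list
#   python3 torchchat.py where llama3
#   python3 torchchat.py remove llama3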