Skip to content

Commit

Permalink
build(utils): only download weights index of the selected tensor format (hyperonym#177)
Browse files Browse the repository at this point in the history

* build(utils): only download weights index of the selected tensor format

* build(utils): preserve markdown files in downloads

---------

Co-authored-by: Gideon Giffard <[email protected]>
  • Loading branch information
peakji and fardeon authored Apr 21, 2023
1 parent 4b1ce7e commit 1c53cf1
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,38 @@
sys.exit("usage: python download.py REPO_ID LOCAL_DIR [REVISION]")

# Choose which weight files (and their shard index) to fetch.
# TENSOR_FORMAT=safetensors selects *.safetensors; any other value,
# including an unset variable, falls back to PyTorch *.bin weights.
if os.getenv("TENSOR_FORMAT") == "safetensors":
    allow_patterns = ["*.safetensors", "*.safetensors.index.json"]
else:
    allow_patterns = ["*.bin", "*.bin.index.json"]

# Patterns excluded from the first download pass: hidden files, every
# weights index, and weight files in any format. Everything else in the
# repository is fetched by that pass.
ignore_patterns = [
    ".*",
    "*.index.json",
    "*.bin",
    "*.ckpt",
    "*.h5",
    "*.mlmodel",
    "*.msgpack",
    "*.onnx",
    "*.ot",
    "*.pb",
    "*.safetensors",
    "*.tar.gz",
    "*.tflite",
]

# Arguments shared by both snapshot_download calls below.
kwargs = {
    "repo_id": sys.argv[1],
    "local_dir": sys.argv[2],
    # The revision is optional on the command line.
    "revision": sys.argv[3] if len(sys.argv) > 3 else None,
    "local_dir_use_symlinks": False,
    "resume_download": True,
}

# The cache lives in a temporary directory that is removed when the
# block exits.
with tempfile.TemporaryDirectory() as cache_dir:
    # Pass 1: everything except the patterns in ignore_patterns.
    huggingface_hub.snapshot_download(
        cache_dir=cache_dir, ignore_patterns=ignore_patterns, **kwargs
    )
    # Pass 2: only the weights of the selected tensor format, plus the
    # matching index file.
    huggingface_hub.snapshot_download(
        cache_dir=cache_dir, allow_patterns=allow_patterns, **kwargs
    )

0 comments on commit 1c53cf1

Please sign in to comment.