Skip to content

Commit

Permalink
Replace md5 with hash in cached_download function for any hash algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Feb 1, 2024
1 parent c6034ef commit 545bca7
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 10 deletions.
82 changes: 74 additions & 8 deletions gdown/cached_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import shutil
import sys
import tempfile
import warnings
from typing import Optional

import filelock

Expand All @@ -18,6 +20,10 @@


def md5sum(filename, blocksize=None):
warnings.warn(
"md5sum is deprecated and will be removed in the future.", FutureWarning
)

if blocksize is None:
blocksize = 65536

Expand All @@ -29,6 +35,10 @@ def md5sum(filename, blocksize=None):


def assert_md5sum(filename, md5, quiet=False, blocksize=None):
warnings.warn(
"assert_md5sum is deprecated and will be removed in the future.", FutureWarning
)

if not (isinstance(md5, str) and len(md5) == 32):
raise ValueError(f"MD5 must be 32 chars: {md5}")

Expand All @@ -43,7 +53,13 @@ def assert_md5sum(filename, md5, quiet=False, blocksize=None):


def cached_download(
url=None, path=None, md5=None, quiet=False, postprocess=None, **kwargs
url=None,
path=None,
md5=None,
quiet=False,
postprocess=None,
hash: Optional[str] = None,
**kwargs,
):
"""Cached download from URL.
Expand All @@ -54,11 +70,14 @@ def cached_download(
path: str, optional
Output filename. Default is basename of URL.
md5: str, optional
Expected MD5 for specified file.
Expected MD5 for specified file. Deprecated in favor of `hash`.
quiet: bool
Suppress terminal output. Default is False.
postprocess: callable
postprocess: callable, optional
Function called with filename as postprocess.
hash: str, optional
Hash value of file in the format of {algorithm}:{hash_value}
such as sha256:abcdef.... Supported algorithms: md5, sha1, sha256, sha512.
kwargs: dict
Keyword arguments to be passed to `download`.
Expand All @@ -76,14 +95,25 @@ def cached_download(
)
path = osp.join(cache_root, path)

if md5 is not None and hash is not None:
raise ValueError("md5 and hash cannot be specified at the same time.")

if md5 is not None:
warnings.warn(
"md5 is deprecated in favor of hash. Please use hash='md5:xxx...' instead.",
FutureWarning,
)
hash = f"md5:{md5}"
del md5

# check existence
if osp.exists(path) and not md5:
if osp.exists(path) and not hash:
if not quiet:
print(f"File exists: {path}", file=sys.stderr)
return path
elif osp.exists(path) and md5:
elif osp.exists(path) and hash:
try:
assert_md5sum(path, md5, quiet=quiet)
_assert_filehash(path=path, hash=hash, quiet=quiet)
return path
except AssertionError as e:
# show warning and overwrite if md5 doesn't match
Expand Down Expand Up @@ -114,11 +144,47 @@ def cached_download(
shutil.rmtree(temp_root)
raise

if md5:
assert_md5sum(path, md5, quiet=quiet)
if hash:
_assert_filehash(path=path, hash=hash, quiet=quiet)

# postprocess
if postprocess is not None:
postprocess(path)

return path


def _compute_filehash(path, algorithm):
    """Compute the hash of a file as ``{algorithm}:{hexdigest}``.

    Parameters
    ----------
    path: str
        Path of the file to hash. It is read in binary mode in fixed-size
        blocks, so the whole file is never held in memory at once.
    algorithm: str
        Name of a fixed-length hashlib algorithm such as md5, sha1, sha256
        or sha512.

    Returns
    -------
    filehash: str
        ``algorithm`` followed by ``:`` and the hexadecimal digest.

    Raises
    ------
    ValueError
        If ``algorithm`` is not one of the fixed-length algorithms in
        ``hashlib.algorithms_guaranteed``.
    """
    BLOCKSIZE = 65536

    # SHAKE algorithms are "guaranteed" too, but their hexdigest() requires an
    # explicit output length, so they cannot produce a canonical file hash.
    # Reject them here instead of failing later with a TypeError.
    supported = sorted(
        name
        for name in hashlib.algorithms_guaranteed
        if not name.startswith("shake_")
    )
    if algorithm not in supported:
        raise ValueError(
            f"Unsupported hash algorithm: {algorithm}. "
            f"Supported algorithms: {supported}"
        )

    algorithm_instance = hashlib.new(algorithm)
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: f.read(BLOCKSIZE), b""):
            algorithm_instance.update(block)
    return f"{algorithm}:{algorithm_instance.hexdigest()}"


def _assert_filehash(path, hash, quiet=False, blocksize=None):
    """Assert that the file at ``path`` has the expected hash.

    Parameters
    ----------
    path: str
        Path of the file to verify.
    hash: str
        Expected hash in the format of {algorithm}:{hash_value}, such as
        sha256:abcdef.... The algorithm name and the hex digest are compared
        case-insensitively and surrounding whitespace is ignored.
    quiet: bool
        Suppress the message printed to stderr on success. Default is False.
    blocksize: int, optional
        Currently unused; accepted only to keep the signature stable.

    Returns
    -------
    matched: bool
        Always True; a mismatch raises instead of returning False.

    Raises
    ------
    ValueError
        If ``hash`` is not a string of the form {algorithm}:{hash_value}.
    AssertionError
        If the computed hash differs from the expected one.
    """
    if not isinstance(hash, str) or ":" not in hash:
        raise ValueError(
            f"Invalid hash: {hash}. "
            "Hash must be in the format of {algorithm}:{hash_value}."
        )

    # Hex digests and algorithm names are often written in upper case (or
    # copied with stray whitespace); normalise so that equivalent spellings
    # are not reported as a mismatch.
    algorithm, _, hash_value = hash.partition(":")
    algorithm = algorithm.strip().lower()
    hash_expected = f"{algorithm}:{hash_value.strip().lower()}"

    hash_actual = _compute_filehash(path=path, algorithm=algorithm)

    if hash_actual == hash_expected:
        if not quiet:
            print(f"File hash matches: {path!r} == {hash!r}", file=sys.stderr)
        return True

    raise AssertionError(
        f"File hash doesn't match:\nactual: {hash_actual}\nexpected: {hash}"
    )
4 changes: 2 additions & 2 deletions tests/test___main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import tempfile

from gdown.cached_download import assert_md5sum
from gdown.cached_download import _assert_filehash

here = os.path.dirname(os.path.abspath(__file__))

Expand All @@ -15,7 +15,7 @@ def _test_cli_with_md5(url_or_id, md5, options=None):
if options is not None:
cmd = f"{cmd} {options}"
subprocess.call(shlex.split(cmd))
assert_md5sum(filename=f.name, md5=md5)
_assert_filehash(path=f.name, hash=f"md5:{md5}")


def _test_cli_with_content(url_or_id, content):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_cached_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import tempfile

import gdown


def _cached_download(**kwargs):
    """Call ``gdown.cached_download`` twice on one temporary output path.

    Both calls use the same URL and the same temporary file name; ``kwargs``
    are forwarded unchanged to ``gdown.cached_download``.
    """
    source_url = "https://drive.google.com/uc?id=0B9P1L--7Wd2vU3VUVlFnbTgtS2c"
    with tempfile.NamedTemporaryFile() as tmp:
        target = tmp.name
        attempts = 0
        while attempts < 2:
            gdown.cached_download(url=source_url, path=target, **kwargs)
            attempts += 1


def test_cached_download_md5():
    """Run the cached-download helper with an ``md5:``-prefixed hash."""
    expected = "md5:cb31a703b96c1ab2f80d164e9676fe7d"
    _cached_download(hash=expected)


def test_cached_download_sha1():
    """Run the cached-download helper with a ``sha1:``-prefixed hash."""
    expected = "sha1:69a5a1000f98237efea9231c8a39d05edf013494"
    _cached_download(hash=expected)


def test_cached_download_sha256():
    """Run the cached-download helper with a ``sha256:``-prefixed hash."""
    expected = (
        "sha256:284e3029cce3ae5ee0b05866100e300046359f53ae4c77fe6b34c05aa7a72cee"
    )
    _cached_download(hash=expected)

0 comments on commit 545bca7

Please sign in to comment.