Skip to content

Commit

Permalink
remote: use progress bar for remote cache query status during dvc gc (
Browse files Browse the repository at this point in the history
iterative#3559)

* remote: use progress bar for remote cache query status during `dvc gc`

* test that cache_checksums() handles unexpected paths

* Fix SSHMocked tests
  • Loading branch information
pmrowla authored Apr 1, 2020
1 parent b02babe commit ee0066f
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 99 deletions.
154 changes: 95 additions & 59 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,27 +697,49 @@ def checksum_to_path_info(self, checksum):
def list_cache_paths(self, prefix=None, progress_callback=None):
raise NotImplementedError

def all(self, *args, **kwargs):
# NOTE: The list might be way too big(e.g. 100M entries, md5 for each
# is 32 bytes, so ~3200Mb list) and we don't really need all of it at
# the same time, so it makes sense to use a generator to gradually
# iterate over it, without keeping all of it in memory.
for path in self.list_cache_paths(*args, **kwargs):
def cache_checksums(self, prefix=None, progress_callback=None):
"""Iterate over remote cache checksums.
If `prefix` is specified, only checksums which begin with `prefix`
will be returned.
"""
for path in self.list_cache_paths(prefix, progress_callback):
try:
yield self.path_to_checksum(path)
except ValueError:
logger.debug(
"'%s' doesn't look like a cache file, skipping", path
)

def gc(self, named_cache):
def all(self, jobs=None, name=None):
"""Iterate over all checksums in the remote cache.
Checksums will be fetched in parallel threads according to prefix
(except for small remotes) and a progress bar will be displayed.
"""
logger.debug(
"Fetching all checksums from '{}'".format(
name if name else "remote cache"
)
)

if not self.CAN_TRAVERSE:
return self.cache_checksums()

remote_size, remote_checksums = self._estimate_cache_size(name=name)
return self._cache_checksums_traverse(
remote_size, remote_checksums, jobs, name
)

def gc(self, named_cache, jobs=None):
logger.debug("named_cache: {} jobs: {}".format(named_cache, jobs))
used = self.extract_used_local_checksums(named_cache)

if self.scheme != "":
used.update(named_cache[self.scheme])

removed = False
for checksum in self.all():
for checksum in self.all(jobs, str(self.path_info)):
if checksum in used:
continue
path_info = self.checksum_to_path_info(checksum)
Expand Down Expand Up @@ -850,7 +872,7 @@ def cache_exists(self, checksums, jobs=None, name=None):

# Max remote size allowed for us to use traverse method
remote_size, remote_checksums = self._estimate_cache_size(
checksums, name=name
checksums, name
)

traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE
Expand All @@ -875,32 +897,25 @@ def cache_exists(self, checksums, jobs=None, name=None):
checksums - remote_checksums, jobs, name
)

if traverse_pages < 256 / self.JOBS:
# Threaded traverse will require making at least 255 more requests
# to the remote, so for small enough remotes, fetching the entire
# list at once will require fewer requests (but also take into
# account that this must be done sequentially rather than in
# parallel)
logger.debug(
"Querying {} checksums via default traverse".format(
len(checksums)
)
)
return list(checksums & set(self.all()))

return self._cache_exists_traverse(
checksums, remote_checksums, remote_size, jobs, name
logger.debug(
"Querying {} checksums via traverse".format(len(checksums))
)
remote_checksums = self._cache_checksums_traverse(
remote_size, remote_checksums, jobs, name
)
return list(checksums & set(remote_checksums))

def _all_with_limit(self, max_paths, prefix=None, progress_callback=None):
def _checksums_with_limit(
self, limit, prefix=None, progress_callback=None
):
count = 0
for checksum in self.all(prefix, progress_callback):
for checksum in self.cache_checksums(prefix, progress_callback):
yield checksum
count += 1
if count > max_paths:
if count > limit:
logger.debug(
"`all()` returned max '{}' checksums, "
"skipping remaining results".format(max_paths)
"`cache_checksums()` returned max '{}' checksums, "
"skipping remaining results".format(limit)
)
return

Expand All @@ -913,33 +928,32 @@ def _max_estimation_size(self, checksums):
* self.LIST_OBJECT_PAGE_SIZE,
)

def _estimate_cache_size(self, checksums, short_circuit=True, name=None):
def _estimate_cache_size(self, checksums=None, name=None):
"""Estimate remote cache size based on number of entries beginning with
"00..." prefix.
"""
prefix = "0" * self.TRAVERSE_PREFIX_LEN
total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN)
if short_circuit:
max_remote_size = self._max_estimation_size(checksums)
if checksums:
max_checksums = self._max_estimation_size(checksums)
else:
max_remote_size = None
max_checksums = None

with Tqdm(
desc="Estimating size of "
+ ("cache in '{}'".format(name) if name else "remote cache"),
unit="file",
total=max_remote_size,
) as pbar:

def update(n=1):
pbar.update(n * total_prefixes)

if max_remote_size:
checksums = self._all_with_limit(
max_remote_size / total_prefixes, prefix, update
if max_checksums:
checksums = self._checksums_with_limit(
max_checksums / total_prefixes, prefix, update
)
else:
checksums = self.all(prefix, update)
checksums = self.cache_checksums(prefix, update)

remote_checksums = set(checksums)
if remote_checksums:
Expand All @@ -949,38 +963,60 @@ def update(n=1):
logger.debug("Estimated remote size: {} files".format(remote_size))
return remote_size, remote_checksums

def _cache_exists_traverse(
self, checksums, remote_checksums, remote_size, jobs=None, name=None
def _cache_checksums_traverse(
self, remote_size, remote_checksums, jobs=None, name=None
):
logger.debug(
"Querying {} checksums via threaded traverse".format(
len(checksums)
)
)

traverse_prefixes = ["{:02x}".format(i) for i in range(1, 256)]
if self.TRAVERSE_PREFIX_LEN > 2:
traverse_prefixes += [
"{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN)
for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2))
]
"""Iterate over all checksums in the remote cache.
Checksums are fetched in parallel according to prefix, except in
cases where the remote size is very small.
All checksums from the remote (including any from the size
estimation step passed via the `remote_checksums` argument) will be
returned.
NOTE: For large remotes the list of checksums will be very
big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list)
and we don't really need all of it at the same time, so it makes
sense to use a generator to gradually iterate over it, without
keeping all of it in memory.
"""
num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE
if num_pages < 256 / self.JOBS:
# Fetching prefixes in parallel requires at least 255 more
# requests, for small enough remotes it will be faster to fetch
# entire cache without splitting it into prefixes.
#
# NOTE: this ends up re-fetching checksums that were already
# fetched during remote size estimation
traverse_prefixes = [None]
initial = 0
else:
yield from remote_checksums
initial = len(remote_checksums)
traverse_prefixes = ["{:02x}".format(i) for i in range(1, 256)]
if self.TRAVERSE_PREFIX_LEN > 2:
traverse_prefixes += [
"{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN)
for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2))
]
with Tqdm(
desc="Querying "
+ ("cache in '{}'".format(name) if name else "remote cache"),
total=remote_size,
initial=len(remote_checksums),
unit="objects",
initial=initial,
unit="file",
) as pbar:

def list_with_update(prefix):
return self.all(prefix=prefix, progress_callback=pbar.update)
return list(
self.cache_checksums(
prefix=prefix, progress_callback=pbar.update
)
)

with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
in_remote = executor.map(list_with_update, traverse_prefixes,)
remote_checksums.update(
itertools.chain.from_iterable(in_remote)
)
return list(checksums & remote_checksums)
yield from itertools.chain.from_iterable(in_remote)

def _cache_object_exists(self, checksums, jobs=None, name=None):
logger.debug(
Expand Down
10 changes: 7 additions & 3 deletions dvc/remote/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class RemoteLOCAL(RemoteBASE):
path_cls = PathInfo
PARAM_CHECKSUM = "md5"
PARAM_PATH = "path"
TRAVERSE_PREFIX_LEN = 2

UNPACKED_DIR_SUFFIX = ".unpacked"

Expand Down Expand Up @@ -57,14 +58,17 @@ def supported(cls, config):
return True

def list_cache_paths(self, prefix=None, progress_callback=None):
assert prefix is None
assert self.path_info is not None
if prefix:
path_info = self.path_info / prefix[:2]
else:
path_info = self.path_info
if progress_callback:
for path in walk_files(self.path_info):
for path in walk_files(path_info):
progress_callback()
yield path
else:
yield from walk_files(self.path_info)
yield from walk_files(path_info)

def get(self, md5):
if not md5:
Expand Down
11 changes: 8 additions & 3 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import logging
import os
import posixpath
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing, contextmanager
Expand Down Expand Up @@ -44,6 +45,7 @@ class RemoteSSH(RemoteBASE):
# paramiko stuff, so we would ideally have it double of server processors.
# We use conservative setting of 4 instead to not exhaust max sessions.
CHECKSUM_JOBS = 4
TRAVERSE_PREFIX_LEN = 2

DEFAULT_CACHE_TYPES = ["copy"]

Expand Down Expand Up @@ -258,15 +260,18 @@ def open(self, path_info, mode="r", encoding=None):
yield io.TextIOWrapper(fd, encoding=encoding)

def list_cache_paths(self, prefix=None, progress_callback=None):
assert prefix is None
if prefix:
root = posixpath.join(self.path_info.path, prefix[:2])
else:
root = self.path_info.path
with self.ssh(self.path_info) as ssh:
# If we simply return an iterator then with above closes instantly
if progress_callback:
for path in ssh.walk_files(self.path_info.path):
for path in ssh.walk_files(root):
progress_callback()
yield path
else:
yield from ssh.walk_files(self.path_info.path)
yield from ssh.walk_files(root)

def walk_files(self, path_info):
with self.ssh(path_info) as ssh:
Expand Down
18 changes: 9 additions & 9 deletions dvc/repo/gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
logger = logging.getLogger(__name__)


def _do_gc(typ, func, clist):
removed = func(clist)
def _do_gc(typ, func, clist, jobs=None):
removed = func(clist, jobs=jobs)
if not removed:
logger.info("No unused '{}' cache to remove.".format(typ))

Expand Down Expand Up @@ -74,22 +74,22 @@ def gc(
)
)

_do_gc("local", self.cache.local.gc, used)
_do_gc("local", self.cache.local.gc, used, jobs)

if self.cache.s3:
_do_gc("s3", self.cache.s3.gc, used)
_do_gc("s3", self.cache.s3.gc, used, jobs)

if self.cache.gs:
_do_gc("gs", self.cache.gs.gc, used)
_do_gc("gs", self.cache.gs.gc, used, jobs)

if self.cache.ssh:
_do_gc("ssh", self.cache.ssh.gc, used)
_do_gc("ssh", self.cache.ssh.gc, used, jobs)

if self.cache.hdfs:
_do_gc("hdfs", self.cache.hdfs.gc, used)
_do_gc("hdfs", self.cache.hdfs.gc, used, jobs)

if self.cache.azure:
_do_gc("azure", self.cache.azure.gc, used)
_do_gc("azure", self.cache.azure.gc, used, jobs)

if cloud:
_do_gc("remote", self.cloud.get_remote(remote, "gc -c").gc, used)
_do_gc("remote", self.cloud.get_remote(remote, "gc -c").gc, used, jobs)
Loading

0 comments on commit ee0066f

Please sign in to comment.