Skip to content

Commit

Permalink
remote: use progress bar when paginating (iterative#3532)
Browse files Browse the repository at this point in the history
* remote: use progress bar when paginating

- add `progress_callback` parameter to `list_cache_paths()` so that
  remotes can update remote cache traverse progress bar after fetching a
  page

* Show remote size estimation status as no-total progress bar

* Update pbar per item rather than per page

* make tests more explicit (don't use mock.ANY)

* Apply suggestions from code review

minor fixes/tidy

* Use progress_callback in ssh/local list_cache_paths()

- keep ssh/local consistent with other remotes (even though `dvc
  gc`/`RemoteBASE.all()` do not currently use progress bars)

* Fix review issues

Co-authored-by: Casper da Costa-Luis <[email protected]>
  • Loading branch information
pmrowla and casperdcl authored Mar 28, 2020
1 parent af6c429 commit acdc876
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 43 deletions.
11 changes: 8 additions & 3 deletions dvc/remote/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class RemoteAZURE(RemoteBASE):
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5
LIST_OBJECT_PAGE_SIZE = 5000

def __init__(self, repo, config):
super().__init__(repo, config)
Expand Down Expand Up @@ -88,7 +89,7 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.blob_service.delete_blob(path_info.bucket, path_info.path)

def _list_paths(self, bucket, prefix):
def _list_paths(self, bucket, prefix, progress_callback=None):
blob_service = self.blob_service
next_marker = None
while True:
Expand All @@ -97,21 +98,25 @@ def _list_paths(self, bucket, prefix):
)

for blob in blobs:
if progress_callback:
progress_callback()
yield blob.name

if not blobs.next_marker:
break

next_marker = blobs.next_marker

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(self.path_info.bucket, prefix)
return self._list_paths(
self.path_info.bucket, prefix, progress_callback
)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand Down
40 changes: 24 additions & 16 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def path_to_checksum(self, path):
def checksum_to_path_info(self, checksum):
return self.path_info / checksum[0:2] / checksum[2:]

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
raise NotImplementedError

def all(self):
Expand Down Expand Up @@ -850,9 +850,20 @@ def cache_exists(self, checksums, jobs=None, name=None):
checksums = frozenset(checksums)
prefix = "0" * self.TRAVERSE_PREFIX_LEN
total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN)
remote_checksums = set(
map(self.path_to_checksum, self.list_cache_paths(prefix=prefix))
)
with Tqdm(
desc="Estimating size of "
+ ("cache in '{}'".format(name) if name else "remote cache"),
unit="file",
) as pbar:

def update(n=1):
pbar.update(n * total_prefixes)

paths = self.list_cache_paths(
prefix=prefix, progress_callback=update
)
remote_checksums = set(map(self.path_to_checksum, paths))

if remote_checksums:
remote_size = total_prefixes * len(remote_checksums)
else:
Expand Down Expand Up @@ -895,11 +906,11 @@ def cache_exists(self, checksums, jobs=None, name=None):
return list(checksums & set(self.all()))

return self._cache_exists_traverse(
checksums, remote_checksums, jobs, name
checksums, remote_checksums, remote_size, jobs, name
)

def _cache_exists_traverse(
self, checksums, remote_checksums, jobs=None, name=None
self, checksums, remote_checksums, remote_size, jobs=None, name=None
):
logger.debug(
"Querying {} checksums via threaded traverse".format(
Expand All @@ -915,20 +926,17 @@ def _cache_exists_traverse(
]
with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
total=len(traverse_prefixes),
unit="dir",
+ ("cache in '{}'".format(name) if name else "remote cache"),
total=remote_size,
initial=len(remote_checksums),
unit="objects",
) as pbar:

def list_with_update(prefix):
ret = map(
self.path_to_checksum,
list(self.list_cache_paths(prefix=prefix)),
paths = self.list_cache_paths(
prefix=prefix, progress_callback=pbar.update
)
pbar.update_desc(
"Querying cache in '{}'".format(self.path_info / prefix)
)
return ret
return map(self.path_to_checksum, list(paths))

with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
in_remote = executor.map(list_with_update, traverse_prefixes,)
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if not self.cache["ids"]:
return

Expand All @@ -479,6 +479,8 @@ def list_cache_paths(self, prefix=None):
query = "({}) and trashed=false".format(parents_query)

for item in self.gdrive_list_item(query):
if progress_callback:
progress_callback()
parent_id = item["parents"][0]["id"]
yield posixpath.join(self.cache["ids"][parent_id], item["title"])

Expand Down
12 changes: 9 additions & 3 deletions dvc/remote/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,24 @@ def remove(self, path_info):

blob.delete()

def _list_paths(self, path_info, max_items=None, prefix=None):
def _list_paths(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
if prefix:
prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:])
else:
prefix = path_info.path
for blob in self.gs.bucket(path_info.bucket).list_blobs(
prefix=path_info.path, max_results=max_items
):
if progress_callback:
progress_callback()
yield blob.name

def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)
def list_cache_paths(self, prefix=None, progress_callback=None):
return self._list_paths(
self.path_info, prefix=prefix, progress_callback=progress_callback
)

def walk_files(self, path_info):
for fname in self._list_paths(path_info / ""):
Expand Down
7 changes: 5 additions & 2 deletions dvc/remote/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def open(self, path_info, mode="r", encoding=None):
raise FileNotFoundError(*e.args)
raise

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if not self.exists(self.path_info):
return

Expand All @@ -166,10 +166,13 @@ def list_cache_paths(self, prefix=None):
with self.hdfs(self.path_info) as hdfs:
while dirs:
try:
for entry in hdfs.ls(dirs.pop(), detail=True):
entries = hdfs.ls(dirs.pop(), detail=True)
for entry in entries:
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
if progress_callback:
progress_callback()
yield urlparse(entry["name"]).path
except IOError as e:
# When searching for a specific prefix pyarrow raises an
Expand Down
9 changes: 7 additions & 2 deletions dvc/remote/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ def cache_dir(self, value):
def supported(cls, config):
return True

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
assert prefix is None
assert self.path_info is not None
return walk_files(self.path_info)
if progress_callback:
for path in walk_files(self.path_info):
progress_callback()
yield path
else:
yield from walk_files(self.path_info)

def get(self, md5):
if not md5:
Expand Down
9 changes: 6 additions & 3 deletions dvc/remote/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RemoteOSS(RemoteBASE):
REQUIRES = {"oss2": "oss2"}
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5
LIST_OBJECT_PAGE_SIZE = 100

def __init__(self, repo, config):
super().__init__(repo, config)
Expand Down Expand Up @@ -90,20 +91,22 @@ def remove(self, path_info):
logger.debug("Removing oss://{}".format(path_info))
self.oss_service.delete_object(path_info.path)

def _list_paths(self, prefix):
def _list_paths(self, prefix, progress_callback=None):
import oss2

for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
if progress_callback:
progress_callback()
yield blob.key

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(prefix)
return self._list_paths(prefix, progress_callback)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand Down
28 changes: 21 additions & 7 deletions dvc/remote/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)

def _list_objects(self, path_info, max_items=None, prefix=None):
def _list_objects(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
""" Read config for list object api, paginate through list objects."""
kwargs = {
"Bucket": path_info.bucket,
Expand All @@ -202,16 +204,28 @@ def _list_objects(self, path_info, max_items=None, prefix=None):
kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2])
paginator = self.s3.get_paginator(self.list_objects_api)
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", ())

def _list_paths(self, path_info, max_items=None, prefix=None):
contents = page.get("Contents", ())
if progress_callback:
for item in contents:
progress_callback()
yield item
else:
yield from contents

def _list_paths(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
return (
item["Key"]
for item in self._list_objects(path_info, max_items, prefix)
for item in self._list_objects(
path_info, max_items, prefix, progress_callback
)
)

def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)
def list_cache_paths(self, prefix=None, progress_callback=None):
return self._list_paths(
self.path_info, prefix=prefix, progress_callback=progress_callback
)

def isfile(self, path_info):
from botocore.exceptions import ClientError
Expand Down
9 changes: 7 additions & 2 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,16 @@ def open(self, path_info, mode="r", encoding=None):
else:
yield io.TextIOWrapper(fd, encoding=encoding)

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
assert prefix is None
with self.ssh(self.path_info) as ssh:
# If we simply return an iterator then with above closes instantly
yield from ssh.walk_files(self.path_info.path)
if progress_callback:
for path in ssh.walk_files(self.path_info.path):
progress_callback()
yield path
else:
yield from ssh.walk_files(self.path_info.path)

def walk_files(self, path_info):
with self.ssh(path_info) as ssh:
Expand Down
26 changes: 22 additions & 4 deletions tests/unit/remote/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
from dvc.remote.base import RemoteMissingDepsError


class _CallableOrNone(object):
"""Helper for testing if object is callable() or None."""

def __eq__(self, other):
return other is None or callable(other)


CallableOrNone = _CallableOrNone()


class TestRemoteBASE(object):
REMOTE_CLS = RemoteBASE

Expand Down Expand Up @@ -82,7 +92,11 @@ def test_cache_exists(path_to_checksum, object_exists, traverse):
remote.cache_exists(checksums)
object_exists.assert_not_called()
traverse.assert_called_with(
frozenset(checksums), set(range(256)), None, None
frozenset(checksums),
set(range(256)),
256 * pow(16, remote.TRAVERSE_PREFIX_LEN),
None,
None,
)

# default traverse
Expand All @@ -105,8 +119,12 @@ def test_cache_exists(path_to_checksum, object_exists, traverse):
def test_cache_exists_traverse(path_to_checksum, list_cache_paths):
remote = RemoteBASE(None, {})
remote.path_info = PathInfo("foo")
remote._cache_exists_traverse({0}, set())
remote._cache_exists_traverse({0}, set(), 4096)
for i in range(1, 16):
list_cache_paths.assert_any_call(prefix="{:03x}".format(i))
list_cache_paths.assert_any_call(
prefix="{:03x}".format(i), progress_callback=CallableOrNone
)
for i in range(1, 256):
list_cache_paths.assert_any_call(prefix="{:02x}".format(i))
list_cache_paths.assert_any_call(
prefix="{:02x}".format(i), progress_callback=CallableOrNone
)

0 comments on commit acdc876

Please sign in to comment.