Skip to content

Commit

Permalink
Merge pull request iterative#2289 from Suor/exists-jobs
Browse files Browse the repository at this point in the history
remote: honor --jobs for status collection and other fixes
  • Loading branch information
efiop authored Jul 21, 2019
2 parents b6bc65f + 1f90b42 commit a1143c5
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 64 deletions.
4 changes: 2 additions & 2 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def changed_cache(self, checksum):
return self._changed_dir_cache(checksum)
return self.changed_cache_file(checksum)

def cache_exists(self, checksums):
def cache_exists(self, checksums, jobs=None):
"""Check if the given checksums are stored in the remote.
There are two ways of performing this check:
Expand Down Expand Up @@ -618,7 +618,7 @@ def exists_with_progress(chunks):
return self.batch_exists(chunks, callback=progress_callback)

if self.no_traverse and hasattr(self, "batch_exists"):
with ThreadPoolExecutor(max_workers=self.JOBS) as executor:
with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
path_infos = [self.checksum_to_path_info(x) for x in checksums]
chunks = to_chunks(path_infos, num_chunks=self.JOBS)
results = executor.map(exists_with_progress, chunks)
Expand Down
6 changes: 3 additions & 3 deletions dvc/remote/local/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def move(self, from_info, to_info):
move(inp, tmp)
move(tmp, outp)

def cache_exists(self, md5s):
def cache_exists(self, md5s, jobs=None):
return [
checksum
for checksum in progress(md5s)
Expand Down Expand Up @@ -306,7 +306,7 @@ def status(
md5s = list(ret)

logger.info("Collecting information from local cache...")
local_exists = self.cache_exists(md5s)
local_exists = self.cache_exists(md5s, jobs=jobs)

# This is a performance optimization. We can safely assume that,
# if the resources that we want to fetch are already cached,
Expand All @@ -316,7 +316,7 @@ def status(
remote_exists = local_exists
else:
logger.info("Collecting information from remote cache...")
remote_exists = list(remote.cache_exists(md5s))
remote_exists = list(remote.cache_exists(md5s, jobs=jobs))

self._fill_statuses(ret, local_exists, remote_exists)

Expand Down
8 changes: 4 additions & 4 deletions dvc/remote/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


@contextmanager
def get_connection(*args, **kwargs):
pool = get_pool(*args, **kwargs)
def get_connection(conn_func, *args, **kwargs):
pool = get_pool(conn_func, *args, **kwargs)
conn = pool.get_connection()
try:
yield conn
Expand All @@ -17,8 +17,8 @@ def get_connection(*args, **kwargs):


@memoize
def get_pool(conn_func, *args, **kwargs):
    """Return a connection pool for the given connection factory.

    Memoized on (conn_func, args, kwargs), so repeated calls with the same
    connection parameters share a single ``Pool`` instance for the lifetime
    of the process (until ``close_pools`` is invoked).
    """
    return Pool(conn_func, *args, **kwargs)


def close_pools():
Expand Down
18 changes: 9 additions & 9 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def _exists(chunk_and_channel):
return ret

with self.ssh(path_infos[0]) as ssh:
with ssh.open_max_sftp_channels() as channels:
max_workers = len(channels)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
paths = [path_info.path for path_info in path_infos]
chunks = to_chunks(paths, num_chunks=max_workers)
chunks_and_channels = zip(chunks, channels)
outcome = executor.map(_exists, chunks_and_channels)
results = list(itertools.chain.from_iterable(outcome))
channels = ssh.open_max_sftp_channels()
max_workers = len(channels)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
paths = [path_info.path for path_info in path_infos]
chunks = to_chunks(paths, num_chunks=max_workers)
chunks_and_channels = zip(chunks, channels)
outcome = executor.map(_exists, chunks_and_channels)
results = list(itertools.chain.from_iterable(outcome))

return results

Expand Down
71 changes: 25 additions & 46 deletions dvc/remote/ssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import errno
import stat
from contextlib import contextmanager
from funcy import cached_property

try:
Expand Down Expand Up @@ -61,26 +60,22 @@ def __init__(self, host, *args, **kwargs):

self._ssh.connect(host, *args, **kwargs)
self._ssh.get_transport().set_keepalive(10)
self._sftp = None
self._sftp_alive = False
self._sftp_channels = []

def _sftp_connect(self):
if not self._sftp or not self._sftp_alive:
self._sftp = self._ssh.open_sftp()
self._sftp_alive = True
@property
def sftp(self):
    """Primary SFTP channel, opened lazily on first access and cached."""
    channels = self._sftp_channels
    if channels:
        return channels[0]
    channel = self._ssh.open_sftp()
    self._sftp_channels = [channel]
    return channel

def close(self):
    """Close every open SFTP channel, then the underlying SSH connection."""
    # Channels are closed first so the transport is torn down last.
    for sftp in self._sftp_channels:
        sftp.close()
    self._ssh.close()

def st_mode(self, path):
    """Return the remote path's ``st_mode``, or 0 if the path is missing."""
    with ignore_file_not_found():
        return self.sftp.stat(path).st_mode

    # Reached only when ignore_file_not_found() suppressed a missing path.
    return 0

Expand All @@ -97,8 +92,6 @@ def islink(self, path):
return stat.S_ISLNK(self.st_mode(path))

def makedirs(self, path):
self._sftp_connect()

# Single stat call will say whether this is a dir, a file or a link
st_mode = self.st_mode(path)

Expand All @@ -117,7 +110,7 @@ def makedirs(self, path):

if tail:
try:
self._sftp.mkdir(path)
self.sftp.mkdir(path)
except IOError as e:
# Since paramiko errors are very vague we need to recheck
# whether it's because path already exists or something else
Expand All @@ -129,11 +122,8 @@ def walk(self, directory, topdown=True):
# used as a template.
#
# [1] https://github.com/python/cpython/blob/master/Lib/os.py

self._sftp_connect()

try:
dir_entries = self._sftp.listdir_attr(directory)
dir_entries = self.sftp.listdir_attr(directory)
except IOError as exc:
raise DvcException(
"couldn't get the '{}' remote directory files list".format(
Expand Down Expand Up @@ -169,7 +159,7 @@ def walk_files(self, directory):

def _remove_file(self, path):
    """Delete a remote file, silently ignoring an already-missing path."""
    with ignore_file_not_found():
        self.sftp.remove(path)

def _remove_dir(self, path):
for root, dirs, files in self.walk(path, topdown=False):
Expand All @@ -181,52 +171,45 @@ def _remove_dir(self, path):
for dname in dirs:
path = posixpath.join(root, dname)
with ignore_file_not_found():
self._sftp.rmdir(dname)
self.sftp.rmdir(dname)

with ignore_file_not_found():
self._sftp.rmdir(path)
self.sftp.rmdir(path)

def remove(self, path):
    """Remove a remote path — a whole tree if it is a directory."""
    if self.isdir(path):
        self._remove_dir(path)
    else:
        self._remove_file(path)

def download(self, src, dest, no_progress_bar=False, progress_title=None):
    """Fetch remote *src* into local *dest* over SFTP.

    Shows a progress bar unless ``no_progress_bar`` is set;
    ``progress_title`` defaults to the basename of *src*.
    """
    if no_progress_bar:
        self.sftp.get(src, dest)
    else:
        if not progress_title:
            progress_title = os.path.basename(src)

        self.sftp.get(src, dest, callback=create_cb(progress_title))
        progress.finish_target(progress_title)

def move(self, src, dst):
    """Rename *src* to *dst* on the remote, creating *dst*'s parent dirs."""
    self.makedirs(posixpath.dirname(dst))
    self.sftp.rename(src, dst)

def upload(self, src, dest, no_progress_bar=False, progress_title=None):
    """Upload local *src* to remote *dest* over SFTP.

    The data is first written to a temporary name and only renamed to
    *dest* afterwards, so *dest* never exists in a half-written state.
    Shows a progress bar unless ``no_progress_bar`` is set;
    ``progress_title`` defaults to the basename of *dest*.
    """
    self.makedirs(posixpath.dirname(dest))
    tmp_file = tmp_fname(dest)

    if no_progress_bar:
        self.sftp.put(src, tmp_file)
    else:
        if not progress_title:
            progress_title = posixpath.basename(dest)

        self.sftp.put(src, tmp_file, callback=create_cb(progress_title))
        progress.finish_target(progress_title)

    self.sftp.rename(tmp_file, dest)

def execute(self, cmd):
stdin, stdout, stderr = self._ssh.exec_command(cmd)
Expand Down Expand Up @@ -307,18 +290,14 @@ def cp(self, src, dest):
self.makedirs(posixpath.dirname(dest))
self.execute("cp {} {}".format(src, dest))

def open_max_sftp_channels(self):
    """Open as many SFTP channels as the server allows and return them.

    Channels are cached on the connection: once more than one channel has
    been opened, subsequent calls return the cached list without opening
    any more. Re-raises ChannelException only if not even a single channel
    could be opened.
    """
    # If there are more than 1 it means we've already opened max amount
    if len(self._sftp_channels) <= 1:
        while True:
            try:
                self._sftp_channels.append(self._ssh.open_sftp())
            except paramiko.ssh_exception.ChannelException:
                if not self._sftp_channels:
                    raise
                break
    return self._sftp_channels
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,11 @@ def erepo(repo_dir):
yield repo
finally:
repo.tearDown()


@pytest.fixture(scope="session", autouse=True)
def _close_pools():
    # Session-wide, autouse teardown: after the whole test session finishes,
    # close any remote connection pools opened during the run.
    # Import is deferred so collecting tests does not eagerly import dvc.
    from dvc.remote.pool import close_pools

    yield
    close_pools()

0 comments on commit a1143c5

Please sign in to comment.