Skip to content

Commit

Permalink
remote: don't close sftp connections until parent ssh conn closed
Browse files Browse the repository at this point in the history
Since this is multiplexing over the same TCP connection, no system
resources are withheld, so this should be a no-brainer optimization.

As a side-effect this should fix the "open max - close all - try open -
Administratively prohibited" error that @darabi is experiencing in iterative#2280.
  • Loading branch information
Suor committed Jul 21, 2019
1 parent 624ccfc commit 31f1ded
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 55 deletions.
18 changes: 9 additions & 9 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def _exists(chunk_and_channel):
return ret

with self.ssh(path_infos[0]) as ssh:
with ssh.open_max_sftp_channels() as channels:
max_workers = len(channels)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
paths = [path_info.path for path_info in path_infos]
chunks = to_chunks(paths, num_chunks=max_workers)
chunks_and_channels = zip(chunks, channels)
outcome = executor.map(_exists, chunks_and_channels)
results = list(itertools.chain.from_iterable(outcome))
channels = ssh.open_max_sftp_channels()
max_workers = len(channels)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
paths = [path_info.path for path_info in path_infos]
chunks = to_chunks(paths, num_chunks=max_workers)
chunks_and_channels = zip(chunks, channels)
outcome = executor.map(_exists, chunks_and_channels)
results = list(itertools.chain.from_iterable(outcome))

return results

Expand Down
71 changes: 25 additions & 46 deletions dvc/remote/ssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import errno
import stat
from contextlib import contextmanager
from funcy import cached_property

try:
Expand Down Expand Up @@ -61,26 +60,22 @@ def __init__(self, host, *args, **kwargs):

self._ssh.connect(host, *args, **kwargs)
self._ssh.get_transport().set_keepalive(10)
self._sftp = None
self._sftp_alive = False
self._sftp_channels = []

def _sftp_connect(self):
if not self._sftp or not self._sftp_alive:
self._sftp = self._ssh.open_sftp()
self._sftp_alive = True
@property
def sftp(self):
if not self._sftp_channels:
self._sftp_channels = [self._ssh.open_sftp()]
return self._sftp_channels[0]

def close(self):
if self._sftp:
self._sftp.close()
self._sftp_alive = False

for sftp in self._sftp_channels:
sftp.close()
self._ssh.close()

def st_mode(self, path):
self._sftp_connect()

with ignore_file_not_found():
return self._sftp.stat(path).st_mode
return self.sftp.stat(path).st_mode

return 0

Expand All @@ -97,8 +92,6 @@ def islink(self, path):
return stat.S_ISLNK(self.st_mode(path))

def makedirs(self, path):
self._sftp_connect()

# Single stat call will say whether this is a dir, a file or a link
st_mode = self.st_mode(path)

Expand All @@ -117,7 +110,7 @@ def makedirs(self, path):

if tail:
try:
self._sftp.mkdir(path)
self.sftp.mkdir(path)
except IOError as e:
# Since paramiko errors are very vague we need to recheck
# whether it's because path already exists or something else
Expand All @@ -129,11 +122,8 @@ def walk(self, directory, topdown=True):
# used as a template.
#
# [1] https://github.com/python/cpython/blob/master/Lib/os.py

self._sftp_connect()

try:
dir_entries = self._sftp.listdir_attr(directory)
dir_entries = self.sftp.listdir_attr(directory)
except IOError as exc:
raise DvcException(
"couldn't get the '{}' remote directory files list".format(
Expand Down Expand Up @@ -169,7 +159,7 @@ def walk_files(self, directory):

def _remove_file(self, path):
with ignore_file_not_found():
self._sftp.remove(path)
self.sftp.remove(path)

def _remove_dir(self, path):
for root, dirs, files in self.walk(path, topdown=False):
Expand All @@ -181,52 +171,45 @@ def _remove_dir(self, path):
for dname in dirs:
path = posixpath.join(root, dname)
with ignore_file_not_found():
self._sftp.rmdir(dname)
self.sftp.rmdir(dname)

with ignore_file_not_found():
self._sftp.rmdir(path)
self.sftp.rmdir(path)

def remove(self, path):
self._sftp_connect()

if self.isdir(path):
self._remove_dir(path)
else:
self._remove_file(path)

def download(self, src, dest, no_progress_bar=False, progress_title=None):
self._sftp_connect()

if no_progress_bar:
self._sftp.get(src, dest)
self.sftp.get(src, dest)
else:
if not progress_title:
progress_title = os.path.basename(src)

self._sftp.get(src, dest, callback=create_cb(progress_title))
self.sftp.get(src, dest, callback=create_cb(progress_title))
progress.finish_target(progress_title)

def move(self, src, dst):
self.makedirs(posixpath.dirname(dst))
self._sftp_connect()
self._sftp.rename(src, dst)
self.sftp.rename(src, dst)

def upload(self, src, dest, no_progress_bar=False, progress_title=None):
self._sftp_connect()

self.makedirs(posixpath.dirname(dest))
tmp_file = tmp_fname(dest)

if no_progress_bar:
self._sftp.put(src, tmp_file)
self.sftp.put(src, tmp_file)
else:
if not progress_title:
progress_title = posixpath.basename(dest)

self._sftp.put(src, tmp_file, callback=create_cb(progress_title))
self.sftp.put(src, tmp_file, callback=create_cb(progress_title))
progress.finish_target(progress_title)

self._sftp.rename(tmp_file, dest)
self.sftp.rename(tmp_file, dest)

def execute(self, cmd):
stdin, stdout, stderr = self._ssh.exec_command(cmd)
Expand Down Expand Up @@ -307,18 +290,14 @@ def cp(self, src, dest):
self.makedirs(posixpath.dirname(dest))
self.execute("cp {} {}".format(src, dest))

@contextmanager
def open_max_sftp_channels(self):
try:
channels = []
# If there are more than 1 it means we've already opened max amount
if len(self._sftp_channels) <= 1:
while True:
try:
channels.append(self._ssh.open_sftp())
self._sftp_channels.append(self._ssh.open_sftp())
except paramiko.ssh_exception.ChannelException:
if not channels:
if not self._sftp_channels:
raise
break
yield channels
finally:
for channel in channels:
channel.close()
return self._sftp_channels

0 comments on commit 31f1ded

Please sign in to comment.