Skip to content

Commit

Permalink
remote: define command error on the base
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr. Outis committed Feb 9, 2019
1 parent 9210e66 commit ec77bfa
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 21 deletions.
10 changes: 6 additions & 4 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def __init__(self, msg):
super(DataCloudError, self).__init__("Data sync error: {}".format(msg))


class RemoteBaseCmdError(DvcException):
def __init__(self, cmd, ret, err):
m = "SSH command '{}' finished with non-zero return code '{}': {}"
super(RemoteBaseCmdError, self).__init__(m.format(cmd, ret, err))
class RemoteCmdError(DvcException):
def __init__(self, remote, cmd, ret, err):
super(RemoteCmdError, self).__init__(
"{remote} command '{cmd}' finished with non-zero return code"
" {ret}': {err}".format(remote=remote, cmd=cmd, ret=ret, err=err)
)


class RemoteMissingDepsError(DvcException):
Expand Down
12 changes: 4 additions & 8 deletions dvc/remote/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@

import dvc.logger as logger
from dvc.config import Config
from dvc.remote.base import RemoteBase, RemoteBaseCmdError
from dvc.remote.base import RemoteBase, RemoteCmdError
from dvc.utils import fix_env


class RemoteHDFSCmdError(RemoteBaseCmdError):
pass


class RemoteHDFS(RemoteBase):
scheme = "hdfs"
REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
Expand Down Expand Up @@ -55,7 +51,7 @@ def hadoop_fs(self, cmd, user=None):
)
out, err = p.communicate()
if p.returncode != 0:
raise RemoteHDFSCmdError(cmd, p.returncode, err)
raise RemoteCmdError(self.scheme, cmd, p.returncode, err)
return out.decode("utf-8")

@staticmethod
Expand Down Expand Up @@ -123,7 +119,7 @@ def exists(self, path_info):
try:
self.hadoop_fs("test -e {}".format(path_info["path"]))
return True
except RemoteHDFSCmdError:
except RemoteCmdError:
return False

def upload(self, from_infos, to_infos, names=None):
Expand Down Expand Up @@ -177,7 +173,7 @@ def download(
def list_cache_paths(self):
try:
self.hadoop_fs("test -e {}".format(self.prefix))
except RemoteHDFSCmdError:
except RemoteCmdError:
return []

stdout = self.hadoop_fs("ls -R {}".format(self.prefix))
Expand Down
10 changes: 3 additions & 7 deletions dvc/remote/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dvc.progress import progress
from dvc.utils.compat import urlparse
from dvc.exceptions import DvcException
from dvc.remote.base import RemoteBase, RemoteBaseCmdError
from dvc.remote.base import RemoteBase, RemoteCmdError


def sizeof_fmt(num, suffix="B"):
Expand All @@ -43,10 +43,6 @@ def create_cb(name):
return lambda cur, tot: percent_cb(name, cur, tot)


class RemoteSSHCmdError(RemoteBaseCmdError):
pass


class RemoteSSH(RemoteBase):
scheme = "ssh"

Expand Down Expand Up @@ -134,7 +130,7 @@ def exists(self, path_info):
try:
self._exec(ssh, "test -e {}".format(path_info["path"]))
exists = True
except RemoteSSHCmdError:
except RemoteCmdError:
exists = False

return exists
Expand Down Expand Up @@ -184,7 +180,7 @@ def _exec(self, ssh, cmd):
ret = stdout.channel.recv_exit_status()
if ret != 0:
err = b"".join(stderr_chunks).decode("utf-8")
raise RemoteSSHCmdError(cmd, ret, err)
raise RemoteCmdError(self.scheme, cmd, ret, err)

return b"".join(stdout_chunks).decode("utf-8")

Expand Down
24 changes: 22 additions & 2 deletions tests/unit/remote/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import mock
from unittest import TestCase

from dvc.remote.base import RemoteBase, RemoteMissingDepsError
from dvc.remote import REMOTES, RemoteLOCAL
from dvc.remote import REMOTES, RemoteLOCAL, RemoteSSH, RemoteHDFS
from dvc.remote.base import RemoteBase, RemoteCmdError, RemoteMissingDepsError


class TestMissingDeps(TestCase):
Expand All @@ -13,3 +13,23 @@ def test(self):
with mock.patch.object(remote_class, "REQUIRES", REQUIRES):
with self.assertRaises(RemoteMissingDepsError):
remote_class(None, {})


class TestCmdError(TestCase):
def test(self):
for remote_class in [RemoteSSH, RemoteHDFS]:
project = None
config = {}

remote_name = remote_class.scheme
cmd = "sed 'hello'"
ret = "1"
err = "sed: expression #1, char 2: extra characters after command"

with mock.patch.object(
remote_class,
"remove",
side_effect=RemoteCmdError(remote_name, cmd, ret, err),
):
with self.assertRaises(RemoteCmdError):
remote_class(project, config).remove("file")

0 comments on commit ec77bfa

Please sign in to comment.