remote: hdfs: use pyarrow and connection pool
Related to iterative#1629

Speeds up HDFS tests significantly. E.g. the ExternalHDFS test goes from
600 sec to 190 sec, and test_open_external[HDFS] from 160 sec to 3 sec.

Signed-off-by: Ruslan Kuprieiev <[email protected]>
efiop committed Jul 19, 2019
1 parent 1e9f1d0 commit 68bebe7
Showing 10 changed files with 130 additions and 76 deletions.
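
For context, a minimal sketch of the access pattern this commit introduces (names are taken from the diff below; the host, port, and user are illustrative placeholders, not values from the commit):

import pyarrow

from dvc.remote.pool import get_connection

# get_connection() is a context manager: it borrows a client from a memoized
# Pool keyed on these arguments and puts it back on exit instead of closing it.
with get_connection(
    pyarrow.hdfs.connect,    # connection factory passed as the first argument
    "namenode.example.com",  # illustrative host
    8020,                    # illustrative port
    user="dvc",              # illustrative user
) as hdfs:
    print(hdfs.exists("/path/in/hdfs"))  # reuses the pooled pyarrow client
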
4 changes: 2 additions & 2 deletions dvc/main.py
@@ -8,7 +8,7 @@
from dvc.config import ConfigError
from dvc.analytics import Analytics
from dvc.exceptions import NotDvcRepoError, DvcParserError
from dvc.remote.ssh.pool import close_ssh_pools
from dvc.remote.pool import close_pools


logger = logging.getLogger("dvc")
@@ -56,7 +56,7 @@ def main(argv=None):
logger.setLevel(outerLogLevel)
# Python 2 fails to close these clean occasionally and users see
# weird error messages, so we do it manually
close_ssh_pools()
close_pools()

Analytics().send_cmd(cmd, args, ret)

2 changes: 1 addition & 1 deletion dvc/path_info.py
@@ -129,7 +129,7 @@ def __repr__(self):


class URLInfo(object):
DEFAULT_PORTS = {"http": 80, "https": 443, "ssh": 22}
DEFAULT_PORTS = {"http": 80, "https": 443, "ssh": 22, "hdfs": 0}

def __init__(self, url):
self.parsed = urlparse(url)
129 changes: 73 additions & 56 deletions dvc/remote/hdfs.py
@@ -7,12 +7,19 @@
import logging
from subprocess import Popen, PIPE

try:
import pyarrow
except ImportError:
pyarrow = None

from dvc.config import Config
from dvc.scheme import Schemes

from dvc.remote.base import RemoteBASE, RemoteCmdError
from dvc.utils.compat import urlparse
from dvc.utils import fix_env, tmp_fname

from .pool import get_connection
from .base import RemoteBASE, RemoteCmdError

logger = logging.getLogger(__name__)

@@ -21,17 +28,39 @@ class RemoteHDFS(RemoteBASE):
scheme = Schemes.HDFS
REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
PARAM_CHECKSUM = "checksum"
REQUIRES = {"pyarrow": pyarrow}

def __init__(self, repo, config):
super(RemoteHDFS, self).__init__(repo, config)
url = config.get(Config.SECTION_REMOTE_URL, "/")
self.path_info = self.path_cls(url)
self.path_info = None
url = config.get(Config.SECTION_REMOTE_URL)
if not url:
return

parsed = urlparse(url)

user = (
parsed.username
or config.get(Config.SECTION_REMOTE_USER)
or getpass.getuser()
)

self.user = self.path_info.user
if not self.user:
self.user = config.get(
Config.SECTION_REMOTE_USER, getpass.getuser()
)
self.path_info = self.path_cls.from_parts(
scheme=self.scheme,
host=parsed.hostname,
user=user,
port=parsed.port,
path=parsed.path,
)

@staticmethod
def hdfs(path_info):
return get_connection(
pyarrow.hdfs.connect,
path_info.host,
path_info.port,
user=path_info.user,
)

def hadoop_fs(self, cmd, user=None):
cmd = "hadoop fs -" + cmd
@@ -65,6 +94,7 @@ def _group(regex, s, gname):
return match.group(gname)

def get_file_checksum(self, path_info):
# NOTE: pyarrow doesn't support checksum, so we need to use hadoop
regex = r".*\t.*\t(?P<checksum>.*)"
stdout = self.hadoop_fs(
"checksum {}".format(path_info.path), user=path_info.user
@@ -73,68 +103,55 @@ def get_file_checksum(self, path_info):

def copy(self, from_info, to_info, **_kwargs):
dname = posixpath.dirname(to_info.path)
self.hadoop_fs("mkdir -p {}".format(dname), user=to_info.user)
self.hadoop_fs(
"cp -f {} {}".format(from_info.path, to_info.path),
user=to_info.user,
)

def rm(self, path_info):
self.hadoop_fs("rm -f {}".format(path_info.path), user=path_info.user)
with self.hdfs(to_info) as hdfs:
hdfs.mkdir(dname)
# NOTE: this is how `hadoop fs -cp` works too: it copies through
# your local machine.
with hdfs.open(from_info.path, "rb") as from_fobj:
with hdfs.open(to_info.path, "wb") as to_fobj:
to_fobj.upload(from_fobj)

def remove(self, path_info):
if path_info.scheme != "hdfs":
raise NotImplementedError

assert path_info.path

logger.debug("Removing {}".format(path_info.path))

self.rm(path_info)
if self.exists(path_info):
logger.debug("Removing {}".format(path_info.path))
with self.hdfs(path_info) as hdfs:
hdfs.rm(path_info.path)

def exists(self, path_info):
assert not isinstance(path_info, list)
assert path_info.scheme == "hdfs"

try:
self.hadoop_fs("test -e {}".format(path_info.path))
return True
except RemoteCmdError:
return False
with self.hdfs(path_info) as hdfs:
return hdfs.exists(path_info.path)

def _upload(self, from_file, to_info, **_kwargs):
self.hadoop_fs(
"mkdir -p {}".format(to_info.parent.url), user=to_info.user
)

tmp_file = tmp_fname(to_info.url)

self.hadoop_fs(
"copyFromLocal {} {}".format(from_file, tmp_file),
user=to_info.user,
)

self.hadoop_fs(
"mv {} {}".format(tmp_file, to_info.url), user=to_info.user
)
with self.hdfs(to_info) as hdfs:
hdfs.mkdir(posixpath.dirname(to_info.path))
tmp_file = tmp_fname(to_info.path)
with open(from_file, "rb") as fobj:
hdfs.upload(tmp_file, fobj)
hdfs.rename(tmp_file, to_info.path)

def _download(self, from_info, to_file, **_kwargs):
self.hadoop_fs(
"copyToLocal {} {}".format(from_info.url, to_file),
user=from_info.user,
)
with self.hdfs(from_info) as hdfs:
with open(to_file, "wb+") as fobj:
hdfs.download(from_info.path, fobj)

def list_cache_paths(self):
try:
self.hadoop_fs("test -e {}".format(self.path_info.url))
except RemoteCmdError:
if not self.exists(self.path_info):
return []

stdout = self.hadoop_fs("ls -R {}".format(self.path_info.url))
lines = stdout.split("\n")
flist = []
for line in lines:
if not line.startswith("-"):
continue
flist.append(line.split()[-1])
return flist
files = []
dirs = [self.path_info.path]

with self.hdfs(self.path_info) as hdfs:
while dirs:
for entry in hdfs.ls(dirs.pop(), detail=True):
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
files.append(urlparse(entry["name"]).path)

return files
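
A hedged sketch of how the new RemoteHDFS.hdfs() helper above is meant to be used inside the remote. The URL and paths are hypothetical; the pooled client is a regular pyarrow HDFS filesystem object, so the usual filesystem calls apply.

from dvc.path_info import URLInfo
from dvc.remote.hdfs import RemoteHDFS

info = URLInfo("hdfs://dvc@namenode:8020/dir/path")  # hypothetical location

# hdfs() wraps get_connection(), yielding a pooled pyarrow HDFS client
# for the host/port/user encoded in path_info.
with RemoteHDFS.hdfs(info) as hdfs:
    if hdfs.exists(info.path):
        with hdfs.open(info.path, "rb") as fobj:
            data = fobj.read()
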
23 changes: 11 additions & 12 deletions dvc/remote/ssh/pool.py → dvc/remote/pool.py
@@ -2,12 +2,10 @@
from contextlib import contextmanager
from funcy import memoize

from .connection import SSHConnection


@contextmanager
def ssh_connection(*conn_args, **conn_kwargs):
pool = get_ssh_pool(*conn_args, **conn_kwargs)
def get_connection(*args, **kwargs):
pool = get_pool(*args, **kwargs)
conn = pool.get_connection()
try:
yield conn
@@ -19,18 +17,19 @@ def ssh_connection(*conn_args, **conn_kwargs):


@memoize
def get_ssh_pool(*conn_args, **conn_kwargs):
return SSHPool(conn_args, conn_kwargs)
def get_pool(*args, **kwargs):
return Pool(*args, **kwargs)


def close_ssh_pools():
for pool in get_ssh_pool.memory.values():
def close_pools():
for pool in get_pool.memory.values():
pool.close()
get_ssh_pool.memory.clear()
get_pool.memory.clear()


class SSHPool(object):
def __init__(self, conn_args, conn_kwargs):
class Pool(object):
def __init__(self, conn_func, *conn_args, **conn_kwargs):
self._conn_func = conn_func
self._conn_args = conn_args
self._conn_kwargs = conn_kwargs
self._conns = deque()
@@ -48,7 +47,7 @@ def get_connection(self):
try:
return self._conns.popleft()
except IndexError:
return SSHConnection(*self._conn_args, **self._conn_kwargs)
return self._conn_func(*self._conn_args, **self._conn_kwargs)

def release(self, conn):
if self._closed:
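
The pool is now generic: the connection factory is passed in rather than hard-coded to SSHConnection, which is what lets HDFS reuse it. A minimal sketch of the SSH side after this change (credentials are placeholders; see the ssh/__init__.py hunk below for the real call site):

from dvc.remote.pool import get_connection
from dvc.remote.ssh.connection import SSHConnection

with get_connection(
    SSHConnection, "example.com", username="user", port=22  # placeholder creds
) as conn:
    ...  # use the connection; it is returned to the memoized pool on exit
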
7 changes: 5 additions & 2 deletions dvc/remote/ssh/__init__.py
@@ -18,7 +18,9 @@
from dvc.utils.compat import urlparse, StringIO
from dvc.remote.base import RemoteBASE
from dvc.scheme import Schemes
from .pool import ssh_connection
from dvc.remote.pool import get_connection

from .connection import SSHConnection


logger = logging.getLogger(__name__)
@@ -121,7 +123,8 @@ def ssh(self, path_info):
)
)

return ssh_connection(
return get_connection(
SSHConnection,
host,
username=user,
port=port,
6 changes: 6 additions & 0 deletions setup.py
@@ -76,8 +76,13 @@ def run(self):
azure = ["azure-storage-blob==2.0.1"]
oss = ["oss2==2.6.1"]
ssh = ["paramiko>=2.5.0"]
hdfs = ["pyarrow>=0.14.0"]
all_remotes = gs + s3 + azure + ssh + oss

if os.name != "nt" or sys.version_info[0] != 2:
# NOTE: there are no pyarrow wheels for python2 on windows
all_remotes += hdfs

# Extra dependecies to run tests
tests_requirements = [
"PyInstaller==3.4",
@@ -125,6 +130,7 @@ def run(self):
"azure": azure,
"oss": oss,
"ssh": ssh,
"hdfs": hdfs,
# NOTE: https://github.com/inveniosoftware/troubleshooting/issues/1
":python_version=='2.7'": ["futures", "pathlib2"],
"tests": tests_requirements,
12 changes: 11 additions & 1 deletion tests/func/test_output.py
@@ -1,3 +1,6 @@
import os
import sys

import pytest

from dvc.stage import Stage
@@ -8,7 +11,14 @@
("s3://bucket/path", "s3"),
("gs://bucket/path", "gs"),
("ssh://example.com:/dir/path", "ssh"),
("hdfs://example.com/dir/path", "hdfs"),
pytest.param(
"hdfs://example.com/dir/path",
"hdfs",
marks=pytest.mark.skipif(
sys.version_info[0] == 2 and os.name == "nt",
reason="Not supported for python 2 on Windows.",
),
),
("path/to/file", "local"),
("path\\to\\file", "local"),
("file", "local"),
8 changes: 8 additions & 0 deletions tests/unit/dependency/test_hdfs.py
@@ -1,8 +1,16 @@
import os
import sys
import pytest

from dvc.dependency.hdfs import DependencyHDFS

from tests.unit.dependency.test_local import TestDependencyLOCAL


@pytest.mark.skipif(
sys.version_info[0] == 2 and os.name == "nt",
reason="Not supported for python 2 on Windows.",
)
class TestDependencyHDFS(TestDependencyLOCAL):
def _get_cls(self):
return DependencyHDFS
8 changes: 8 additions & 0 deletions tests/unit/output/test_hdfs.py
@@ -1,8 +1,16 @@
import os
import sys
import pytest

from dvc.output.hdfs import OutputHDFS

from tests.unit.output.test_local import TestOutputLOCAL


@pytest.mark.skipif(
sys.version_info[0] == 2 and os.name == "nt",
reason="Not supported for python 2 on Windows.",
)
class TestOutputHDFS(TestOutputLOCAL):
def _get_cls(self):
return OutputHDFS
7 changes: 5 additions & 2 deletions tests/unit/remote/ssh/test_pool.py
@@ -1,11 +1,14 @@
import pytest

from dvc.remote.ssh.pool import ssh_connection
from dvc.remote.pool import get_connection
from dvc.remote.ssh.connection import SSHConnection


def test_doesnt_swallow_errors(ssh_server):
class MyError(Exception):
pass

with pytest.raises(MyError), ssh_connection(**ssh_server.test_creds):
with pytest.raises(MyError), get_connection(
SSHConnection, **ssh_server.test_creds
):
raise MyError
