From 51a8c782aeac0b758adef9a161ae97b17503cb01 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Thu, 18 Jun 2020 01:36:13 +0300 Subject: [PATCH] tests: create remote fixtures (#4063) * tests: use remote fixtures * tests: use remote fixtures for data cloud tests * tests: treat remotes as tmp dirs --- setup.cfg | 2 +- setup.py | 1 + tests/conftest.py | 35 +- tests/dir_helpers.py | 39 +- tests/func/test_api.py | 128 ++-- tests/func/test_data_cloud.py | 738 +++++------------------ tests/func/test_external_repo.py | 8 +- tests/func/test_gc.py | 8 +- tests/func/test_get.py | 7 +- tests/func/test_import.py | 7 +- tests/func/test_remote.py | 13 +- tests/func/test_repro.py | 6 +- tests/func/test_run_cache.py | 5 +- tests/func/test_tree.py | 7 +- tests/remotes.py | 295 --------- tests/remotes/__init__.py | 35 ++ tests/remotes/azure.py | 39 ++ tests/remotes/base.py | 23 + tests/remotes/gdrive.py | 55 ++ tests/remotes/gs.py | 90 +++ tests/remotes/hdfs.py | 55 ++ tests/remotes/http.py | 31 + tests/remotes/local.py | 26 + tests/remotes/oss.py | 45 ++ tests/remotes/s3.py | 68 +++ tests/remotes/ssh.py | 115 ++++ tests/{ => remotes}/user.key | 0 tests/{ => remotes}/user.key.pub | 0 tests/unit/remote/ssh/test_connection.py | 52 +- tests/unit/remote/ssh/test_pool.py | 14 +- tests/unit/remote/ssh/test_ssh.py | 14 +- 31 files changed, 857 insertions(+), 1104 deletions(-) delete mode 100644 tests/remotes.py create mode 100644 tests/remotes/__init__.py create mode 100644 tests/remotes/azure.py create mode 100644 tests/remotes/base.py create mode 100644 tests/remotes/gdrive.py create mode 100644 tests/remotes/gs.py create mode 100644 tests/remotes/hdfs.py create mode 100644 tests/remotes/http.py create mode 100644 tests/remotes/local.py create mode 100644 tests/remotes/oss.py create mode 100644 tests/remotes/s3.py create mode 100644 tests/remotes/ssh.py rename tests/{ => remotes}/user.key (100%) rename tests/{ => remotes}/user.key.pub (100%) diff --git a/setup.cfg b/setup.cfg index fbcef890d5..2d6e2132f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ select=B,C,E,F,W,T4,B9 [isort] include_trailing_comma=true known_first_party=dvc,tests -known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,google,grandalf,mock,mockssh,moto,nanotime,networkx,packaging,paramiko,pathspec,pytest,requests,ruamel,setuptools,shortuuid,tqdm,voluptuous,yaml,zc +known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,google,grandalf,mock,moto,nanotime,networkx,packaging,paramiko,pathspec,pytest,requests,ruamel,setuptools,shortuuid,tqdm,voluptuous,yaml,zc line_length=79 force_grid_wrap=0 use_parentheses=True diff --git a/setup.py b/setup.py index 87dad025b7..8570cdf263 100644 --- a/setup.py +++ b/setup.py @@ -106,6 +106,7 @@ def run(self): "pytest-cov>=2.6.1", "pytest-xdist>=1.26.1", "pytest-mock==1.11.2", + "pytest-lazy-fixture", "flaky>=3.5.3", "mock>=3.0.0", "xmltodict>=0.11.0", diff --git a/tests/conftest.py b/tests/conftest.py index 5816ab53f0..54ba074e9d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,9 @@ import os -import mockssh import pytest -from dvc.remote.ssh.connection import SSHConnection -from tests.utils.httpd import PushRequestHandler, StaticFileServer - from .dir_helpers import * # noqa +from .remotes import * # noqa # Prevent updater and analytics from running their processes os.environ["DVC_TEST"] = "true" @@ -28,39 +25,9 @@ def reset_loglevel(request, caplog): yield -here = 
os.path.abspath(os.path.dirname(__file__)) - -user = "user" -key_path = os.path.join(here, f"{user}.key") - - -@pytest.fixture(scope="session") -def ssh_server(): - users = {user: key_path} - with mockssh.Server(users) as s: - s.test_creds = { - "host": s.host, - "port": s.port, - "username": user, - "key_filename": key_path, - } - yield s - - -@pytest.fixture -def ssh(ssh_server): - yield SSHConnection(**ssh_server.test_creds) - - @pytest.fixture(scope="session", autouse=True) def _close_pools(): from dvc.remote.pool import close_pools yield close_pools() - - -@pytest.fixture -def http_server(tmp_dir): - with StaticFileServer(handler_class=PushRequestHandler) as httpd: - yield httpd diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index e54f31ab7d..8480e0ec6e 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -63,7 +63,6 @@ "run_head", "erepo_dir", "git_dir", - "setup_remote", "git_init", ] @@ -177,6 +176,27 @@ def scm_add(self, filenames, commit=None): if commit: self.scm.commit(commit) + def add_remote( + self, *, url=None, config=None, name="upstream", default=True + ): + self._require("dvc") + + assert bool(url) ^ bool(config) + + if url: + config = {"url": url} + + with self.dvc.config.edit() as conf: + conf["remote"][name] = config + if default: + conf["core"]["remote"] = name + + if hasattr(self, "scm"): + self.scm.add(self.dvc.config.files["repo"]) + self.scm.commit(f"add '{name}' remote") + + return url or config["url"] + # contexts @contextmanager def chdir(self): @@ -315,20 +335,3 @@ def git_dir(make_tmp_dir): path = make_tmp_dir("git-erepo", scm=True) path.scm.commit("init repo") return path - - -@pytest.fixture -def setup_remote(make_tmp_dir): - def create(repo, url=None, name="upstream", default=True): - if not url: - url = os.fspath(make_tmp_dir("local_remote")) - with repo.config.edit() as conf: - conf["remote"][name] = {"url": url} - if default: - conf["core"]["remote"] = name - - repo.scm.add(repo.config.files["repo"]) - repo.scm.commit(f"add '{name}' remote") - return url - - return create diff --git a/tests/func/test_api.py b/tests/func/test_api.py index 3b4042b458..aee81b56d0 100644 --- a/tests/func/test_api.py +++ b/tests/func/test_api.py @@ -5,93 +5,42 @@ from dvc import api from dvc.api import UrlNotDvcRepoError from dvc.exceptions import FileMissingError -from dvc.main import main from dvc.path_info import URLInfo from dvc.utils.fs import remove -from tests.remotes import ( - GCP, - HDFS, - OSS, - S3, - SSH, - TEST_REMOTE, - Azure, - GDrive, - Local, -) - -remote_params = [S3, GCP, Azure, GDrive, OSS, SSH, HDFS] -all_remote_params = [Local] + remote_params - - -@pytest.fixture -def remote_url(request): - if not request.param.should_test(): - raise pytest.skip() - return request.param.get_url() - - -def run_dvc(*argv): - assert main(argv) == 0 - - -def ensure_dir(dvc, url): - if url.startswith("gdrive://"): - GDrive.create_dir(dvc, url) - run_dvc( - "remote", - "modify", - TEST_REMOTE, - "gdrive_service_account_email", - "test", - ) - run_dvc( - "remote", - "modify", - TEST_REMOTE, - "gdrive_service_account_p12_file_path", - "test.p12", - ) - run_dvc( - "remote", - "modify", - TEST_REMOTE, - "gdrive_use_service_account", - "True", - ) - - -def ensure_dir_scm(dvc, url): - if url.startswith("gdrive://"): - GDrive.create_dir(dvc, url) - with dvc.config.edit() as conf: - conf["remote"][TEST_REMOTE].update( - gdrive_service_account_email="test", - gdrive_service_account_p12_file_path="test.p12", - gdrive_use_service_account=True, - ) - 
dvc.scm.add(dvc.config.files["repo"]) - dvc.scm.commit(f"modify '{TEST_REMOTE}' remote") - - -@pytest.mark.parametrize("remote_url", remote_params, indirect=True) -def test_get_url(tmp_dir, dvc, remote_url): - run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url) + +cloud_names = [ + "s3", + "gs", + "azure", + "gdrive", + "oss", + "ssh", + "hdfs", + "http", +] +clouds = [pytest.lazy_fixture(cloud) for cloud in cloud_names] +all_clouds = [pytest.lazy_fixture("local")] + clouds +remotes = [pytest.lazy_fixture(f"{cloud}_remote") for cloud in cloud_names] +all_remotes = [pytest.lazy_fixture("local_remote")] + remotes + + +@pytest.mark.parametrize("remote", remotes) +def test_get_url(tmp_dir, dvc, request, remote): tmp_dir.dvc_gen("foo", "foo") - expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8" + expected_url = URLInfo(remote.url) / "ac/bd18db4cc2f85cedef654fccc4a4d8" assert api.get_url("foo") == expected_url -@pytest.mark.parametrize("remote_url", remote_params, indirect=True) -def test_get_url_external(erepo_dir, remote_url, setup_remote): - setup_remote(erepo_dir.dvc, url=remote_url) +@pytest.mark.parametrize("cloud", clouds) +def test_get_url_external(erepo_dir, cloud): + erepo_dir.add_remote(config=cloud.config) with erepo_dir.chdir(): erepo_dir.dvc_gen("foo", "foo", commit="add foo") # Using file url to force clone to tmp repo repo_url = f"file://{erepo_dir}" - expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8" + expected_url = URLInfo(cloud.url) / "ac/bd18db4cc2f85cedef654fccc4a4d8" assert api.get_url("foo", repo=repo_url) == expected_url @@ -105,12 +54,10 @@ def test_get_url_requires_dvc(tmp_dir, scm): api.get_url("foo", repo=f"file://{tmp_dir}") -@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True) -def test_open(remote_url, tmp_dir, dvc): - run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url) - ensure_dir(dvc, remote_url) +@pytest.mark.parametrize("remote", all_remotes) +def test_open(tmp_dir, dvc, remote): tmp_dir.dvc_gen("foo", "foo-text") - run_dvc("push") + dvc.push() # Remove cache to force download remove(dvc.cache.local.cache_dir) @@ -119,10 +66,9 @@ def test_open(remote_url, tmp_dir, dvc): assert fd.read() == "foo-text" -@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True) -def test_open_external(remote_url, erepo_dir, setup_remote): - setup_remote(erepo_dir.dvc, url=remote_url) - ensure_dir_scm(erepo_dir.dvc, remote_url) +@pytest.mark.parametrize("cloud", clouds) +def test_open_external(erepo_dir, cloud): + erepo_dir.add_remote(config=cloud.config) with erepo_dir.chdir(): erepo_dir.dvc_gen("version", "master", commit="add version") @@ -144,12 +90,10 @@ def test_open_external(remote_url, erepo_dir, setup_remote): assert api.read("version", repo=repo_url, rev="branch") == "branchver" -@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True) -def test_open_granular(remote_url, tmp_dir, dvc): - run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url) - ensure_dir(dvc, remote_url) +@pytest.mark.parametrize("remote", all_remotes) +def test_open_granular(tmp_dir, dvc, remote): tmp_dir.dvc_gen({"dir": {"foo": "foo-text"}}) - run_dvc("push") + dvc.push() # Remove cache to force download remove(dvc.cache.local.cache_dir) @@ -158,11 +102,9 @@ def test_open_granular(remote_url, tmp_dir, dvc): assert fd.read() == "foo-text" -@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True) -def test_missing(remote_url, tmp_dir, dvc): +@pytest.mark.parametrize("remote", 
all_remotes) +def test_missing(tmp_dir, dvc, remote): tmp_dir.dvc_gen("foo", "foo") - run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url) - ensure_dir(dvc, remote_url) # Remove cache to make foo missing remove(dvc.cache.local.cache_dir) diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 8400b846ba..6ffba48192 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -1,497 +1,163 @@ -import copy import logging import os import shutil -import uuid -from unittest import SkipTest import pytest from flaky.flaky_decorator import flaky from dvc.cache import NamedCache -from dvc.data_cloud import DataCloud from dvc.external_repo import clean_repos from dvc.main import main -from dvc.remote.azure import AzureRemoteTree from dvc.remote.base import STATUS_DELETED, STATUS_NEW, STATUS_OK -from dvc.remote.gdrive import GDriveRemoteTree -from dvc.remote.gs import GSRemoteTree -from dvc.remote.hdfs import HDFSRemoteTree -from dvc.remote.http import HTTPRemoteTree from dvc.remote.local import LocalRemoteTree -from dvc.remote.oss import OSSRemoteTree -from dvc.remote.s3 import S3RemoteTree -from dvc.remote.ssh import SSHRemoteTree from dvc.stage.exceptions import StageNotFound from dvc.utils import file_md5 from dvc.utils.fs import remove from dvc.utils.yaml import dump_yaml, load_yaml from tests.basic_env import TestDvc -from tests.remotes import ( - GCP, - HDFS, - HTTP, - OSS, - S3, - TEST_CONFIG, - TEST_GCP_CREDS_FILE, - TEST_REMOTE, - Azure, - GDrive, - Local, - SSHMocked, -) - - -class TestDataCloud(TestDvc): - def _test_cloud(self, config, cl): - self.dvc.config = config - cloud = DataCloud(self.dvc) - self.assertIsInstance(cloud.get_remote().tree, cl) - - def test(self): - config = copy.deepcopy(TEST_CONFIG) - - clist = [ - ("s3://mybucket/", S3RemoteTree), - ("gs://mybucket/", GSRemoteTree), - ("ssh://user@localhost:/", SSHRemoteTree), - ("http://localhost:8000/", HTTPRemoteTree), - ("azure://ContainerName=mybucket;conn_string;", AzureRemoteTree), - ("oss://mybucket/", OSSRemoteTree), - (TestDvc.mkdtemp(), LocalRemoteTree), - ] - - for scheme, cl in clist: - remote_url = scheme + str(uuid.uuid4()) - config["remote"][TEST_REMOTE]["url"] = remote_url - self._test_cloud(config, cl) - - -class TestDataCloudBase(TestDvc): - def _get_cloud_class(self): - return None - - @staticmethod - def should_test(): - return False - - @staticmethod - def get_url(): - return NotImplementedError - - def _get_keyfile(self): - return None - - def _ensure_should_run(self): - if not self.should_test(): - raise SkipTest(f"Test {self.__class__.__name__} is disabled") - - def _setup_cloud(self): - self._ensure_should_run() - - repo = self.get_url() - keyfile = self._get_keyfile() - - config = copy.deepcopy(TEST_CONFIG) - config["remote"][TEST_REMOTE] = {"url": repo, "keyfile": keyfile} - self.dvc.config = config - self.cloud = DataCloud(self.dvc) - - self.assertIsInstance( - self.cloud.get_remote().tree, self._get_cloud_class() - ) - - def _test_cloud(self): - self._setup_cloud() - - stages = self.dvc.add(self.FOO) - self.assertEqual(len(stages), 1) - stage = stages[0] - self.assertTrue(stage is not None) - out = stage.outs[0] - cache = out.cache_path - md5 = out.checksum - info = out.get_used_cache() - - stages = self.dvc.add(self.DATA_DIR) - self.assertEqual(len(stages), 1) - stage_dir = stages[0] - self.assertTrue(stage_dir is not None) - out_dir = stage_dir.outs[0] - cache_dir = out_dir.cache_path - name_dir = str(out_dir) - md5_dir = out_dir.checksum - info_dir = 
NamedCache.make(out_dir.scheme, md5_dir, name_dir) - - with self.cloud.repo.state: - # Check status - status = self.cloud.status(info, show_checksums=True) - expected = {md5: {"name": md5, "status": STATUS_NEW}} - self.assertEqual(status, expected) - - status_dir = self.cloud.status(info_dir, show_checksums=True) - expected = {md5_dir: {"name": md5_dir, "status": STATUS_NEW}} - self.assertEqual(status_dir, expected) - - # Push and check status - self.cloud.push(info) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - - self.cloud.push(info_dir) - self.assertTrue(os.path.isfile(cache_dir)) - - status = self.cloud.status(info, show_checksums=True) - expected = {md5: {"name": md5, "status": STATUS_OK}} - self.assertEqual(status, expected) - - status_dir = self.cloud.status(info_dir, show_checksums=True) - expected = {md5_dir: {"name": md5_dir, "status": STATUS_OK}} - self.assertEqual(status_dir, expected) - - # Remove and check status - remove(self.dvc.cache.local.cache_dir) - - status = self.cloud.status(info, show_checksums=True) - expected = {md5: {"name": md5, "status": STATUS_DELETED}} - self.assertEqual(status, expected) - - status_dir = self.cloud.status(info_dir, show_checksums=True) - expected = {md5_dir: {"name": md5_dir, "status": STATUS_DELETED}} - self.assertEqual(status_dir, expected) - - # Pull and check status - self.cloud.pull(info) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - with open(cache) as fd: - self.assertEqual(fd.read(), self.FOO_CONTENTS) - - self.cloud.pull(info_dir) - self.assertTrue(os.path.isfile(cache_dir)) - - status = self.cloud.status(info, show_checksums=True) - expected = {md5: {"name": md5, "status": STATUS_OK}} - self.assertEqual(status, expected) - - status_dir = self.cloud.status(info_dir, show_checksums=True) - expected = {md5_dir: {"name": md5_dir, "status": STATUS_OK}} - self.assertTrue(status_dir, expected) - - def test(self): - self._ensure_should_run() - self._test_cloud() - -class TestS3Remote(S3, TestDataCloudBase): - def _get_cloud_class(self): - return S3RemoteTree +from .test_api import all_remotes -class TestGDriveRemote(GDrive, TestDataCloudBase): - def _setup_cloud(self): - self._ensure_should_run() +@pytest.mark.parametrize("remote", all_remotes) +def test_cloud(tmp_dir, dvc, remote): + (stage,) = tmp_dir.dvc_gen("foo", "foo") + out = stage.outs[0] + cache = out.cache_path + md5 = out.checksum + info = out.get_used_cache() - url = self.get_url() - self.create_dir(self.dvc, url) - - config = copy.deepcopy(TEST_CONFIG) - config["remote"][TEST_REMOTE] = { - "url": url, - "gdrive_service_account_email": "test", - "gdrive_service_account_p12_file_path": "test.p12", - "gdrive_use_service_account": True, + (stage_dir,) = tmp_dir.dvc_gen( + { + "data_dir": { + "data_sub_dir": {"data_sub": "data_sub"}, + "data": "data", + } } - - self.dvc.config = config - self.cloud = DataCloud(self.dvc) - remote = self.cloud.get_remote() - self.assertIsInstance(remote.tree, self._get_cloud_class()) - - def _get_cloud_class(self): - return GDriveRemoteTree - - -class TestGSRemote(GCP, TestDataCloudBase): - def _setup_cloud(self): - self._ensure_should_run() - - repo = self.get_url() - - config = copy.deepcopy(TEST_CONFIG) - config["remote"][TEST_REMOTE] = { - "url": repo, - "credentialpath": TEST_GCP_CREDS_FILE, - } - self.dvc.config = config - self.cloud = DataCloud(self.dvc) - - self.assertIsInstance( - self.cloud.get_remote().tree, self._get_cloud_class() - ) - - def 
_get_cloud_class(self): - return GSRemoteTree - - -class TestAzureRemote(Azure, TestDataCloudBase): - def _get_cloud_class(self): - return AzureRemoteTree - - -class TestOSSRemote(OSS, TestDataCloudBase): - def _get_cloud_class(self): - return OSSRemoteTree - - -class TestLocalRemote(Local, TestDataCloudBase): - def _get_cloud_class(self): - return LocalRemoteTree - - -@pytest.mark.usefixtures("ssh_server") -class TestSSHRemoteMocked(SSHMocked, TestDataCloudBase): - @pytest.fixture(autouse=True) - def setup_method_fixture(self, request, ssh_server): - self.ssh_server = ssh_server - self.method_name = request.function.__name__ - - def _setup_cloud(self): - self._ensure_should_run() - - repo = self.get_url() - keyfile = self._get_keyfile() - - self._get_cloud_class().CAN_TRAVERSE = False - config = copy.deepcopy(TEST_CONFIG) - config["remote"][TEST_REMOTE] = { - "url": repo, - "keyfile": keyfile, - } - self.dvc.config = config - self.cloud = DataCloud(self.dvc) - - self.assertIsInstance( - self.cloud.get_remote().tree, self._get_cloud_class() - ) - - def get_url(self): - user = self.ssh_server.test_creds["username"] - return super().get_url(user, self.ssh_server.port) - - def _get_keyfile(self): - return self.ssh_server.test_creds["key_filename"] - - def _get_cloud_class(self): - return SSHRemoteTree - - -class TestHDFSRemote(HDFS, TestDataCloudBase): - def _get_cloud_class(self): - return HDFSRemoteTree - - -@pytest.mark.usefixtures("http_server") -class TestHTTPRemote(HTTP, TestDataCloudBase): - @pytest.fixture(autouse=True) - def setup_method_fixture(self, request, http_server): - self.http_server = http_server - self.method_name = request.function.__name__ - - def get_url(self): - return super().get_url(self.http_server.server_port) - - def _get_cloud_class(self): - return HTTPRemoteTree - - -class TestDataCloudCLIBase(TestDvc): - def main(self, args): - ret = main(args) - self.assertEqual(ret, 0) - - @staticmethod - def should_test(): - return False - - @staticmethod - def get_url(): - raise NotImplementedError - - def _setup_cloud(self): - pass - - def _test_cloud(self, remote=None): - self._setup_cloud() - - args = ["-v", "-j", "2"] - if remote: - args += ["-r", remote] - else: - args += [] - - stages = self.dvc.add(self.FOO) - self.assertEqual(len(stages), 1) - stage = stages[0] - self.assertTrue(stage is not None) - cache = stage.outs[0].cache_path - - stages = self.dvc.add(self.DATA_DIR) - self.assertEqual(len(stages), 1) - stage_dir = stages[0] - self.assertTrue(stage_dir is not None) - cache_dir = stage_dir.outs[0].cache_path - - # FIXME check status output - - self.main(["push"] + args) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - self.assertTrue(os.path.isfile(cache_dir)) - - remove(self.dvc.cache.local.cache_dir) - - self.main(["fetch"] + args) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - self.assertTrue(os.path.isfile(cache_dir)) - - self.main(["pull"] + args) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - self.assertTrue(os.path.isfile(cache_dir)) - self.assertTrue(os.path.isfile(self.FOO)) - self.assertTrue(os.path.isdir(self.DATA_DIR)) - + ) + out_dir = stage_dir.outs[0] + cache_dir = out_dir.cache_path + name_dir = str(out_dir) + md5_dir = out_dir.checksum + info_dir = NamedCache.make(out_dir.scheme, md5_dir, name_dir) + + with dvc.state: + # Check status + status = dvc.cloud.status(info, show_checksums=True) + expected = {md5: {"name": md5, "status": 
STATUS_NEW}} + assert status == expected + + status_dir = dvc.cloud.status(info_dir, show_checksums=True) + expected = {md5_dir: {"name": md5_dir, "status": STATUS_NEW}} + assert status_dir == expected + + # Push and check status + dvc.cloud.push(info) + assert os.path.exists(cache) + assert os.path.isfile(cache) + + dvc.cloud.push(info_dir) + assert os.path.isfile(cache_dir) + + status = dvc.cloud.status(info, show_checksums=True) + expected = {md5: {"name": md5, "status": STATUS_OK}} + assert status == expected + + status_dir = dvc.cloud.status(info_dir, show_checksums=True) + expected = {md5_dir: {"name": md5_dir, "status": STATUS_OK}} + assert status_dir == expected + + # Remove and check status + remove(dvc.cache.local.cache_dir) + + status = dvc.cloud.status(info, show_checksums=True) + expected = {md5: {"name": md5, "status": STATUS_DELETED}} + assert status == expected + + status_dir = dvc.cloud.status(info_dir, show_checksums=True) + expected = {md5_dir: {"name": md5_dir, "status": STATUS_DELETED}} + assert status_dir == expected + + # Pull and check status + dvc.cloud.pull(info) + assert os.path.exists(cache) + assert os.path.isfile(cache) with open(cache) as fd: - self.assertEqual(fd.read(), self.FOO_CONTENTS) - self.assertTrue(os.path.isfile(cache_dir)) - - # NOTE: check if remote gc works correctly on directories - self.main(["gc", "-cw", "-f"] + args) - shutil.move( - self.dvc.cache.local.cache_dir, - self.dvc.cache.local.cache_dir + ".back", - ) - - self.main(["fetch"] + args) - - self.main(["pull", "-f"] + args) - self.assertTrue(os.path.exists(cache)) - self.assertTrue(os.path.isfile(cache)) - self.assertTrue(os.path.isfile(cache_dir)) - self.assertTrue(os.path.isfile(self.FOO)) - self.assertTrue(os.path.isdir(self.DATA_DIR)) - - def _test(self): - pass - - def test(self): - if not self.should_test(): - raise SkipTest(f"Test {self.__class__.__name__} is disabled") - self._test() - - -class TestLocalRemoteCLI(Local, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() - - self.main(["remote", "add", TEST_REMOTE, url]) - - self._test_cloud(TEST_REMOTE) - + assert fd.read() == "foo" -class TestRemoteHDFSCLI(HDFS, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() + dvc.cloud.pull(info_dir) + assert os.path.isfile(cache_dir) - self.main(["remote", "add", TEST_REMOTE, url]) + status = dvc.cloud.status(info, show_checksums=True) + expected = {md5: {"name": md5, "status": STATUS_OK}} + assert status == expected - self._test_cloud(TEST_REMOTE) + status_dir = dvc.cloud.status(info_dir, show_checksums=True) + expected = {md5_dir: {"name": md5_dir, "status": STATUS_OK}} + assert status_dir == expected -class TestS3RemoteCLI(S3, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() +@pytest.mark.parametrize("remote", all_remotes) +def test_cloud_cli(tmp_dir, dvc, remote): + args = ["-v", "-j", "2"] - self.main(["remote", "add", TEST_REMOTE, url]) + (stage,) = tmp_dir.dvc_gen("foo", "foo") + cache = stage.outs[0].cache_path - self._test_cloud(TEST_REMOTE) - - -class TestGDriveRemoteCLI(GDrive, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() - - self.create_dir(self.dvc, url) - - self.main(["remote", "add", TEST_REMOTE, url]) - self.main( - [ - "remote", - "modify", - TEST_REMOTE, - "gdrive_service_account_email", - "test", - ] - ) - self.main( - [ - "remote", - "modify", - TEST_REMOTE, - "gdrive_service_account_p12_file_path", - "test.p12", - ] - ) - self.main( - [ - "remote", - "modify", - TEST_REMOTE, - 
"gdrive_use_service_account", - "True", - ] - ) - - self._test_cloud(TEST_REMOTE) - - -class TestGSRemoteCLI(GCP, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() + (stage_dir,) = tmp_dir.dvc_gen( + { + "data_dir": { + "data_sub_dir": {"data_sub": "data_sub"}, + "data": "data", + } + } + ) + assert stage_dir is not None + cache_dir = stage_dir.outs[0].cache_path - self.main(["remote", "add", TEST_REMOTE, url]) - self.main( - [ - "remote", - "modify", - TEST_REMOTE, - "credentialpath", - TEST_GCP_CREDS_FILE, - ] - ) + # FIXME check status output - self._test_cloud(TEST_REMOTE) + assert main(["push"] + args) == 0 + assert os.path.exists(cache) + assert os.path.isfile(cache) + assert os.path.isfile(cache_dir) + remove(dvc.cache.local.cache_dir) -class TestAzureRemoteCLI(Azure, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() + assert main(["fetch"] + args) == 0 + assert os.path.exists(cache) + assert os.path.isfile(cache) + assert os.path.isfile(cache_dir) - self.main(["remote", "add", TEST_REMOTE, url]) + assert main(["pull"] + args) == 0 + assert os.path.exists(cache) + assert os.path.isfile(cache) + assert os.path.isfile(cache_dir) + assert os.path.isfile("foo") + assert os.path.isdir("data_dir") - self._test_cloud(TEST_REMOTE) + with open(cache) as fd: + assert fd.read() == "foo" + assert os.path.isfile(cache_dir) + # NOTE: http doesn't support gc yet + if remote.url.startswith("http"): + return -class TestOSSRemoteCLI(OSS, TestDataCloudCLIBase): - def _test(self): - url = self.get_url() + # NOTE: check if remote gc works correctly on directories + assert main(["gc", "-cw", "-f"] + args) == 0 + shutil.move( + dvc.cache.local.cache_dir, dvc.cache.local.cache_dir + ".back", + ) - self.main(["remote", "add", TEST_REMOTE, url]) + assert main(["fetch"] + args) == 0 - self._test_cloud(TEST_REMOTE) + assert main(["pull", "-f"] + args) == 0 + assert os.path.exists(cache) + assert os.path.isfile(cache) + assert os.path.isfile(cache_dir) + assert os.path.isfile("foo") + assert os.path.isdir("data_dir") class TestDataCloudErrorCLI(TestDvc): @@ -507,124 +173,30 @@ def test_error(self): self.main_fail(["fetch", f]) -class TestWarnOnOutdatedStage(TestDvc): - def main(self, args): - ret = main(args) - self.assertEqual(ret, 0) - - def _test(self): - url = Local.get_url() - self.main(["remote", "add", "-d", TEST_REMOTE, url]) - - stage = self.dvc.run( - outs=["bar"], cmd="echo bar > bar", single_stage=True - ) - self.main(["push"]) - - stage_file_path = stage.relpath - content = load_yaml(stage_file_path) - del content["outs"][0]["md5"] - dump_yaml(stage_file_path, content) - - with self._caplog.at_level(logging.WARNING, logger="dvc"): - self._caplog.clear() - self.main(["status", "-c"]) - expected_warning = ( - "Output 'bar'(stage: 'bar.dvc') is missing version info. " - "Cache for it will not be collected. " - "Use `dvc repro` to get your pipeline up to date." 
- ) - - assert expected_warning in self._caplog.text - - def test(self): - self._test() - - -class TestRecursiveSyncOperations(Local, TestDataCloudBase): - def main(self, args): - ret = main(args) - self.assertEqual(ret, 0) +def test_warn_on_outdated_stage(tmp_dir, dvc, local_remote, caplog): + stage = dvc.run(outs=["bar"], cmd="echo bar > bar", single_stage=True) + assert main(["push"]) == 0 - def _get_cloud_class(self): - return LocalRemoteTree + stage_file_path = stage.relpath + content = load_yaml(stage_file_path) + del content["outs"][0]["md5"] + dump_yaml(stage_file_path, content) - def _prepare_repo(self): - remote = self.cloud.get_remote() - self.main( - ["remote", "add", "-d", TEST_REMOTE, remote.path_info.fspath] + with caplog.at_level(logging.WARNING, logger="dvc"): + caplog.clear() + assert main(["status", "-c"]) == 0 + expected_warning = ( + "Output 'bar'(stage: 'bar.dvc') is missing version info. " + "Cache for it will not be collected. " + "Use `dvc repro` to get your pipeline up to date." ) - self.dvc.add(self.DATA) - self.dvc.add(self.DATA_SUB) - - def _remove_local_data_files(self): - os.remove(self.DATA) - os.remove(self.DATA_SUB) - - def _test_recursive_pull(self): - self._remove_local_data_files() - self._clear_local_cache() - - self.assertFalse(os.path.exists(self.DATA)) - self.assertFalse(os.path.exists(self.DATA_SUB)) - - self.main(["pull", "-R", self.DATA_DIR]) - - self.assertTrue(os.path.exists(self.DATA)) - self.assertTrue(os.path.exists(self.DATA_SUB)) - - def _clear_local_cache(self): - remove(self.dvc.cache.local.cache_dir) - - def _test_recursive_fetch(self, data_md5, data_sub_md5): - self._clear_local_cache() + assert expected_warning in caplog.text - local_cache = self.dvc.cache.local - local_cache_data_path = local_cache.hash_to_path_info(data_md5) - local_cache_data_sub_path = local_cache.hash_to_path_info(data_sub_md5) - self.assertFalse(os.path.exists(local_cache_data_path)) - self.assertFalse(os.path.exists(local_cache_data_sub_path)) - - self.main(["fetch", "-R", self.DATA_DIR]) - - self.assertTrue(os.path.exists(local_cache_data_path)) - self.assertTrue(os.path.exists(local_cache_data_sub_path)) - - def _test_recursive_push(self, data_md5, data_sub_md5): - remote = self.cloud.get_remote() - cloud_data_path = remote.hash_to_path_info(data_md5) - cloud_data_sub_path = remote.hash_to_path_info(data_sub_md5) - - self.assertFalse(os.path.exists(cloud_data_path)) - self.assertFalse(os.path.exists(cloud_data_sub_path)) - - self.main(["push", "-R", self.DATA_DIR]) - - self.assertTrue(os.path.exists(cloud_data_path)) - self.assertTrue(os.path.exists(cloud_data_sub_path)) - - def test(self): - self._setup_cloud() - self._prepare_repo() - - data_md5 = file_md5(self.DATA)[0] - data_sub_md5 = file_md5(self.DATA_SUB)[0] - - self._test_recursive_push(data_md5, data_sub_md5) - - self._test_recursive_fetch(data_md5, data_sub_md5) - - self._test_recursive_pull() - - -def test_hash_recalculation(mocker, dvc, tmp_dir): +def test_hash_recalculation(mocker, dvc, tmp_dir, local_remote): tmp_dir.gen({"foo": "foo"}) test_get_file_hash = mocker.spy(LocalRemoteTree, "get_file_hash") - url = Local.get_url() - ret = main(["remote", "add", "-d", TEST_REMOTE, url]) - assert ret == 0 ret = main(["config", "cache.type", "hardlink"]) assert ret == 0 ret = main(["add", "foo"]) @@ -687,10 +259,8 @@ def test(self): def test_verify_hashes( - tmp_dir, scm, dvc, mocker, tmp_path_factory, setup_remote + tmp_dir, scm, dvc, mocker, tmp_path_factory, local_remote ): - - setup_remote(dvc, 
name="upstream") tmp_dir.dvc_gen({"file": "file1 content"}, commit="add file") tmp_dir.dvc_gen({"dir": {"subfile": "file2 content"}}, commit="add dir") dvc.push() @@ -782,8 +352,7 @@ def recurse_list_dir(d): ] -def test_dvc_pull_pipeline_stages(tmp_dir, dvc, run_copy, setup_remote): - setup_remote(dvc) +def test_dvc_pull_pipeline_stages(tmp_dir, dvc, run_copy, local_remote): (stage0,) = tmp_dir.dvc_gen("foo", "foo") stage1 = run_copy("foo", "bar", single_stage=True) stage2 = run_copy("bar", "foobar", name="copy-bar-foobar") @@ -814,8 +383,8 @@ def test_dvc_pull_pipeline_stages(tmp_dir, dvc, run_copy, setup_remote): assert set(stats["added"]) == set(outs) -def test_pipeline_file_target_ops(tmp_dir, dvc, run_copy, setup_remote): - remote_path = setup_remote(dvc) +def test_pipeline_file_target_ops(tmp_dir, dvc, run_copy, local_remote): + path = local_remote.url tmp_dir.dvc_gen("foo", "foo") run_copy("foo", "bar", single_stage=True) @@ -834,7 +403,7 @@ def test_pipeline_file_target_ops(tmp_dir, dvc, run_copy, setup_remote): outs = ["foo", "bar", "lorem", "ipsum", "baz", "lorem2"] # each one's a copy of other, hence 3 - assert len(recurse_list_dir(remote_path)) == 3 + assert len(recurse_list_dir(path)) == 3 clean(outs, dvc) assert set(dvc.pull(["dvc.yaml"])["added"]) == {"lorem2", "baz"} @@ -845,13 +414,13 @@ def test_pipeline_file_target_ops(tmp_dir, dvc, run_copy, setup_remote): # clean everything in remote and push from tests.dir_helpers import TmpDir - clean(TmpDir(remote_path).iterdir()) + clean(TmpDir(path).iterdir()) dvc.push(["dvc.yaml:copy-ipsum-baz"]) - assert len(recurse_list_dir(remote_path)) == 1 + assert len(recurse_list_dir(path)) == 1 - clean(TmpDir(remote_path).iterdir()) + clean(TmpDir(path).iterdir()) dvc.push(["dvc.yaml"]) - assert len(recurse_list_dir(remote_path)) == 2 + assert len(recurse_list_dir(path)) == 2 with pytest.raises(StageNotFound): dvc.push(["dvc.yaml:StageThatDoesNotExist"]) @@ -868,8 +437,7 @@ def test_pipeline_file_target_ops(tmp_dir, dvc, run_copy, setup_remote): ({}, "Everything is up to date"), ], ) -def test_push_stats(tmp_dir, dvc, fs, msg, caplog, setup_remote): - setup_remote(dvc) +def test_push_stats(tmp_dir, dvc, fs, msg, caplog, local_remote): tmp_dir.dvc_gen(fs) caplog.clear() @@ -886,8 +454,7 @@ def test_push_stats(tmp_dir, dvc, fs, msg, caplog, setup_remote): ({}, "Everything is up to date."), ], ) -def test_fetch_stats(tmp_dir, dvc, fs, msg, caplog, setup_remote): - setup_remote(dvc) +def test_fetch_stats(tmp_dir, dvc, fs, msg, caplog, local_remote): tmp_dir.dvc_gen(fs) dvc.push() clean(list(fs.keys()), dvc) @@ -897,8 +464,7 @@ def test_fetch_stats(tmp_dir, dvc, fs, msg, caplog, setup_remote): assert msg in caplog.text -def test_pull_stats(tmp_dir, dvc, caplog, setup_remote): - setup_remote(dvc) +def test_pull_stats(tmp_dir, dvc, caplog, local_remote): tmp_dir.dvc_gen({"foo": "foo", "bar": "bar"}) dvc.push() clean(["foo", "bar"], dvc) @@ -919,8 +485,7 @@ def test_pull_stats(tmp_dir, dvc, caplog, setup_remote): @pytest.mark.parametrize( "key,expected", [("all_tags", 2), ("all_branches", 3), ("all_commits", 3)] ) -def test_push_pull_all(tmp_dir, scm, dvc, setup_remote, key, expected): - setup_remote(dvc) +def test_push_pull_all(tmp_dir, scm, dvc, local_remote, key, expected): tmp_dir.dvc_gen({"foo": "foo"}, commit="first") scm.tag("v1") dvc.remove("foo.dvc") @@ -936,13 +501,12 @@ def test_push_pull_all(tmp_dir, scm, dvc, setup_remote, key, expected): assert dvc.pull(**{key: True})["fetched"] == expected -def 
test_push_pull_fetch_pipeline_stages(tmp_dir, dvc, run_copy, setup_remote): - remote_path = setup_remote(dvc) +def test_push_pull_fetch_pipeline_stages(tmp_dir, dvc, run_copy, local_remote): tmp_dir.dvc_gen("foo", "foo") run_copy("foo", "bar", no_commit=True, name="copy-foo-bar") dvc.push("copy-foo-bar") - assert len(recurse_list_dir(remote_path)) == 1 + assert len(recurse_list_dir(local_remote.url)) == 1 # pushing everything so as we can check pull/fetch only downloads # from specified targets dvc.push() diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index a2f8ef098c..f868cf472c 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -43,8 +43,8 @@ def test_source_change(erepo_dir): assert old_rev != new_rev -def test_cache_reused(erepo_dir, mocker, setup_remote): - setup_remote(erepo_dir.dvc) +def test_cache_reused(erepo_dir, mocker, local_cloud): + erepo_dir.add_remote(config=local_cloud.config) with erepo_dir.chdir(): erepo_dir.dvc_gen("file", "text", commit="add file") erepo_dir.dvc.push() @@ -99,7 +99,7 @@ def test_pull_subdir_file(tmp_dir, erepo_dir): assert dest.read_text() == "contents" -def test_relative_remote(erepo_dir, tmp_dir, setup_remote): +def test_relative_remote(erepo_dir, tmp_dir): # these steps reproduce the script on this issue: # https://github.com/iterative/dvc/issues/2756 with erepo_dir.chdir(): @@ -107,7 +107,7 @@ def test_relative_remote(erepo_dir, tmp_dir, setup_remote): upstream_dir = tmp_dir upstream_url = relpath(upstream_dir, erepo_dir) - setup_remote(erepo_dir.dvc, url=upstream_url, name="upstream") + erepo_dir.add_remote(url=upstream_url) erepo_dir.dvc.push() diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py index 847fc1529e..28c1d31542 100644 --- a/tests/func/test_gc.py +++ b/tests/func/test_gc.py @@ -240,9 +240,9 @@ def test_gc_without_workspace_raises_error(tmp_dir, dvc): dvc.gc(force=True, workspace=False) -def test_gc_cloud_with_or_without_specifier(tmp_dir, erepo_dir, setup_remote): +def test_gc_cloud_with_or_without_specifier(tmp_dir, erepo_dir, local_cloud): + erepo_dir.add_remote(config=local_cloud.config) dvc = erepo_dir.dvc - setup_remote(dvc) from dvc.exceptions import InvalidArgumentError with pytest.raises(InvalidArgumentError): @@ -297,9 +297,7 @@ def test_gc_with_possible_args_positive(tmp_dir, dvc): assert main(["gc", "-vf", flag]) == 0 -def test_gc_cloud_positive(tmp_dir, dvc, tmp_path_factory, setup_remote): - setup_remote(dvc) - +def test_gc_cloud_positive(tmp_dir, dvc, tmp_path_factory, local_remote): for flag in ["-cw", "-ca", "-cT", "-caT", "-cwT"]: assert main(["gc", "-vf", flag]) == 0 diff --git a/tests/func/test_get.py b/tests/func/test_get.py index 20ff9f9d66..1fab440c47 100644 --- a/tests/func/test_get.py +++ b/tests/func/test_get.py @@ -201,8 +201,8 @@ def test_get_file_from_dir(tmp_dir, erepo_dir): assert (tmp_dir / "X").read_text() == "foo" -def test_get_url_positive(tmp_dir, erepo_dir, caplog, setup_remote): - setup_remote(erepo_dir.dvc) +def test_get_url_positive(tmp_dir, erepo_dir, caplog, local_cloud): + erepo_dir.add_remote(config=local_cloud.config) with erepo_dir.chdir(): erepo_dir.dvc_gen("foo", "foo") erepo_dir.dvc.push() @@ -238,11 +238,10 @@ def test_get_url_git_only_repo(tmp_dir, scm, caplog): def test_get_pipeline_tracked_outs( - tmp_dir, dvc, scm, git_dir, run_copy, setup_remote + tmp_dir, dvc, scm, git_dir, run_copy, local_remote ): from dvc.dvcfile import PIPELINE_FILE, PIPELINE_LOCK - setup_remote(dvc) tmp_dir.gen("foo", "foo") 
run_copy("foo", "bar", name="copy-foo-bar") dvc.push() diff --git a/tests/func/test_import.py b/tests/func/test_import.py index a3b4328d06..b97d9ca102 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -307,13 +307,13 @@ def test_pull_no_rev_lock(erepo_dir, tmp_dir, dvc): def test_import_from_bare_git_repo( - tmp_dir, make_tmp_dir, erepo_dir, setup_remote + tmp_dir, make_tmp_dir, erepo_dir, local_cloud ): import git git.Repo.init(os.fspath(tmp_dir), bare=True) - setup_remote(erepo_dir.dvc) + erepo_dir.add_remote(config=local_cloud.config) with erepo_dir.chdir(): erepo_dir.dvc_gen({"foo": "foo"}, commit="initial") erepo_dir.dvc.push() @@ -327,11 +327,10 @@ def test_import_from_bare_git_repo( def test_import_pipeline_tracked_outs( - tmp_dir, dvc, scm, erepo_dir, run_copy, setup_remote + tmp_dir, dvc, scm, erepo_dir, run_copy, local_remote ): from dvc.dvcfile import PIPELINE_FILE, PIPELINE_LOCK - setup_remote(dvc) tmp_dir.gen("foo", "foo") run_copy("foo", "bar", name="copy-foo-bar") dvc.push() diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index a424224cd9..1e61e6fd39 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -173,9 +173,7 @@ def test_dir_hash_should_be_key_order_agnostic(tmp_dir, dvc): assert hash1 == hash2 -def test_partial_push_n_pull(tmp_dir, dvc, tmp_path_factory, setup_remote): - setup_remote(dvc, name="upstream") - +def test_partial_push_n_pull(tmp_dir, dvc, tmp_path_factory, local_remote): foo = tmp_dir.dvc_gen({"foo": "foo content"})[0].outs[0] bar = tmp_dir.dvc_gen({"bar": "bar content"})[0].outs[0] baz = tmp_dir.dvc_gen({"baz": {"foo": "baz content"}})[0].outs[0] @@ -211,9 +209,8 @@ def unreliable_upload(self, from_file, to_info, name=None, **kwargs): def test_raise_on_too_many_open_files( - tmp_dir, dvc, tmp_path_factory, mocker, setup_remote + tmp_dir, dvc, tmp_path_factory, mocker, local_remote ): - setup_remote(dvc) tmp_dir.dvc_gen({"file": "file content"}) mocker.patch.object( @@ -246,8 +243,7 @@ def test_external_dir_resource_on_no_cache(tmp_dir, dvc, tmp_path_factory): ) -def test_push_order(tmp_dir, dvc, tmp_path_factory, mocker, setup_remote): - setup_remote(dvc) +def test_push_order(tmp_dir, dvc, tmp_path_factory, mocker, local_remote): tmp_dir.dvc_gen({"foo": {"bar": "bar content"}}) tmp_dir.dvc_gen({"baz": "baz content"}) @@ -387,8 +383,7 @@ def test_remote_default(dvc): assert local_config["core"]["remote"] == new_name -def test_protect_local_remote(tmp_dir, dvc, setup_remote): - setup_remote(dvc, name="upstream") +def test_protect_local_remote(tmp_dir, dvc, local_remote): (stage,) = tmp_dir.dvc_gen("file", "file content") dvc.push() diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 75d62ece71..28a523cd8c 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1737,12 +1737,14 @@ def test_downstream(dvc): reason="external output scenario is not supported on Windows", ) def test_ssh_dir_out(tmp_dir, dvc, ssh_server): + from tests.remotes.ssh import TEST_SSH_USER, TEST_SSH_KEY_PATH + tmp_dir.gen({"foo": "foo content"}) # Set up remote and cache - user = ssh_server.test_creds["username"] + user = TEST_SSH_USER port = ssh_server.port - keyfile = ssh_server.test_creds["key_filename"] + keyfile = TEST_SSH_KEY_PATH remote_url = SSHMocked.get_url(user, port) assert main(["remote", "add", "upstream", remote_url]) == 0 diff --git a/tests/func/test_run_cache.py b/tests/func/test_run_cache.py index 7b510b5a3a..5d995304e6 100644 --- a/tests/func/test_run_cache.py +++ 
b/tests/func/test_run_cache.py @@ -8,12 +8,11 @@ def _recurse_count_files(path): return len([os.path.join(r, f) for r, _, fs in os.walk(path) for f in fs]) -def test_push_pull(tmp_dir, dvc, erepo_dir, run_copy, setup_remote): +def test_push_pull(tmp_dir, dvc, erepo_dir, run_copy, local_remote): tmp_dir.gen("foo", "foo") run_copy("foo", "bar", name="copy-foo-bar") - url = setup_remote(dvc) assert dvc.push(run_cache=True) == 2 - setup_remote(erepo_dir.dvc, url) + erepo_dir.add_remote(config=local_remote.config) with erepo_dir.chdir(): assert not os.path.exists(erepo_dir.dvc.stage_cache.cache_dir) assert erepo_dir.dvc.pull(run_cache=True)["fetched"] == 2 diff --git a/tests/func/test_tree.py b/tests/func/test_tree.py index b08b014e45..bdfe7c71b2 100644 --- a/tests/func/test_tree.py +++ b/tests/func/test_tree.py @@ -179,8 +179,7 @@ def test_branch(self): ) -def test_repotree_walk_fetch(tmp_dir, dvc, scm, setup_remote): - setup_remote(dvc) +def test_repotree_walk_fetch(tmp_dir, dvc, scm, local_remote): out = tmp_dir.dvc_gen({"dir": {"foo": "foo"}}, commit="init")[0].outs[0] dvc.push() remove(dvc.cache.local.cache_dir) @@ -196,12 +195,12 @@ def test_repotree_walk_fetch(tmp_dir, dvc, scm, setup_remote): assert os.path.exists(dvc.cache.local.hash_to_path_info(hash_)) -def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, setup_remote): +def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, local_cloud): with erepo_dir.chdir(): erepo_dir.gen({"dir": {"subdir": {"foo": "foo"}, "bar": "bar"}}) erepo_dir.dvc_add("dir/subdir", commit="subdir") erepo_dir.scm_add("dir", commit="dir") - setup_remote(erepo_dir.dvc) + erepo_dir.add_remote(config=local_cloud.config) erepo_dir.dvc.push() # test only cares that either fetch or stream are set so that DVC dirs are diff --git a/tests/remotes.py b/tests/remotes.py deleted file mode 100644 index c1000f5fac..0000000000 --- a/tests/remotes.py +++ /dev/null @@ -1,295 +0,0 @@ -import getpass -import os -import platform -import uuid -from contextlib import contextmanager -from subprocess import CalledProcessError, Popen, check_output - -from moto.s3 import mock_s3 - -from dvc.remote.base import Remote -from dvc.remote.gdrive import GDriveRemoteTree -from dvc.remote.gs import GSRemoteTree -from dvc.remote.s3 import S3RemoteTree -from dvc.utils import env2bool -from tests.basic_env import TestDvc - -TEST_REMOTE = "upstream" -TEST_CONFIG = { - "cache": {}, - "core": {"remote": TEST_REMOTE}, - "remote": {TEST_REMOTE: {"url": ""}}, -} - -TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-temp") -TEST_GCP_REPO_BUCKET = os.environ.get("DVC_TEST_GCP_REPO_BUCKET", "dvc-test") -TEST_GDRIVE_REPO_BUCKET = "root" -TEST_OSS_REPO_BUCKET = "dvc-test" - -TEST_GCP_CREDS_FILE = os.path.abspath( - os.environ.get( - "GOOGLE_APPLICATION_CREDENTIALS", - os.path.join("scripts", "ci", "gcp-creds.json"), - ) -) -# Ensure that absolute path is used -os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE - -always_test = staticmethod(lambda: True) - - -class Local: - should_test = always_test - - @staticmethod - def get_storagepath(): - return TestDvc.mkdtemp() - - @staticmethod - def get_url(): - return Local.get_storagepath() - - -class S3: - @staticmethod - def should_test(): - do_test = env2bool("DVC_TEST_AWS", undefined=None) - if do_test is not None: - return do_test - - if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv( - "AWS_SECRET_ACCESS_KEY" - ): - return True - - return False - - @staticmethod - def get_storagepath(): - return TEST_AWS_REPO_BUCKET 
+ "/" + str(uuid.uuid4()) - - @staticmethod - def get_url(): - return "s3://" + S3.get_storagepath() - - -class S3Mocked(S3): - should_test = always_test - - @classmethod - @contextmanager - def remote(cls, repo): - with mock_s3(): - yield Remote(S3RemoteTree(repo, {"url": cls.get_url()})) - - @staticmethod - def put_objects(remote, objects): - s3 = remote.tree.s3 - bucket = remote.path_info.bucket - s3.create_bucket(Bucket=bucket) - for key, body in objects.items(): - s3.put_object( - Bucket=bucket, Key=(remote.path_info / key).path, Body=body - ) - - -class GCP: - @staticmethod - def should_test(): - do_test = env2bool("DVC_TEST_GCP", undefined=None) - if do_test is not None: - return do_test - - if not os.path.exists(TEST_GCP_CREDS_FILE): - return False - - try: - check_output( - [ - "gcloud", - "auth", - "activate-service-account", - "--key-file", - TEST_GCP_CREDS_FILE, - ] - ) - except (CalledProcessError, OSError): - return False - return True - - @staticmethod - def get_storagepath(): - return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4()) - - @staticmethod - def get_url(): - return "gs://" + GCP.get_storagepath() - - @classmethod - @contextmanager - def remote(cls, repo): - yield Remote(GSRemoteTree(repo, {"url": cls.get_url()})) - - @staticmethod - def put_objects(remote, objects): - client = remote.tree.gs - bucket = client.get_bucket(remote.path_info.bucket) - for key, body in objects.items(): - bucket.blob((remote.path_info / key).path).upload_from_string(body) - - -class GDrive: - @staticmethod - def should_test(): - return os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) is not None - - @staticmethod - def create_dir(dvc, url): - config = { - "url": url, - "gdrive_service_account_email": "test", - "gdrive_service_account_p12_file_path": "test.p12", - "gdrive_use_service_account": True, - } - tree = GDriveRemoteTree(dvc, config) - tree._gdrive_create_dir("root", tree.path_info.path) - - @staticmethod - def get_storagepath(): - return TEST_GDRIVE_REPO_BUCKET + "/" + str(uuid.uuid4()) - - @staticmethod - def get_url(): - # NOTE: `get_url` should always return new random url - return "gdrive://" + GDrive.get_storagepath() - - -class Azure: - @staticmethod - def should_test(): - do_test = env2bool("DVC_TEST_AZURE", undefined=None) - if do_test is not None: - return do_test - - return os.getenv("AZURE_STORAGE_CONTAINER_NAME") and os.getenv( - "AZURE_STORAGE_CONNECTION_STRING" - ) - - @staticmethod - def get_url(): - container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME") - assert container_name is not None - return "azure://{}/{}".format(container_name, str(uuid.uuid4())) - - -class OSS: - @staticmethod - def should_test(): - do_test = env2bool("DVC_TEST_OSS", undefined=None) - if do_test is not None: - return do_test - - return ( - os.getenv("OSS_ENDPOINT") - and os.getenv("OSS_ACCESS_KEY_ID") - and os.getenv("OSS_ACCESS_KEY_SECRET") - ) - - @staticmethod - def get_storagepath(): - return f"{TEST_OSS_REPO_BUCKET}/{uuid.uuid4()}" - - @staticmethod - def get_url(): - return f"oss://{OSS.get_storagepath()}" - - -class SSH: - @staticmethod - def should_test(): - do_test = env2bool("DVC_TEST_SSH", undefined=None) - if do_test is not None: - return do_test - - # FIXME: enable on windows - if os.name == "nt": - return False - - try: - check_output(["ssh", "-o", "BatchMode=yes", "127.0.0.1", "ls"]) - except (CalledProcessError, OSError): - return False - - return True - - @staticmethod - def get_url(): - return "ssh://{}@127.0.0.1:22{}".format( - getpass.getuser(), 
Local.get_storagepath() - ) - - -class SSHMocked: - should_test = always_test - - @staticmethod - def get_url(user, port): - path = Local.get_storagepath() - if os.name == "nt": - # NOTE: On Windows Local.get_storagepath() will return an - # ntpath that looks something like `C:\some\path`, which is not - # compatible with SFTP paths [1], so we need to convert it to - # a proper posixpath. - # To do that, we should construct a posixpath that would be - # relative to the server's root. - # In our case our ssh server is running with `c:/` as a root, - # and our URL format requires absolute paths, so the - # resulting path would look like `/some/path`. - # - # [1]https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6 - drive, path = os.path.splitdrive(path) - assert drive.lower() == "c:" - path = path.replace("\\", "/") - url = f"ssh://{user}@127.0.0.1:{port}{path}" - return url - - -class HDFS: - @staticmethod - def should_test(): - if platform.system() != "Linux": - return False - - try: - check_output( - ["hadoop", "version"], - shell=True, - executable=os.getenv("SHELL"), - ) - except (CalledProcessError, OSError): - return False - - p = Popen( - "hadoop fs -ls hdfs://127.0.0.1/", - shell=True, - executable=os.getenv("SHELL"), - ) - p.communicate() - if p.returncode != 0: - return False - - return True - - @staticmethod - def get_url(): - return "hdfs://{}@127.0.0.1{}".format( - getpass.getuser(), Local.get_storagepath() - ) - - -class HTTP: - should_test = always_test - - @staticmethod - def get_url(port): - return f"http://127.0.0.1:{port}" diff --git a/tests/remotes/__init__.py b/tests/remotes/__init__.py new file mode 100644 index 0000000000..6b23448e1e --- /dev/null +++ b/tests/remotes/__init__.py @@ -0,0 +1,35 @@ +TEST_REMOTE = "upstream" +TEST_CONFIG = { + "cache": {}, + "core": {"remote": TEST_REMOTE}, + "remote": {TEST_REMOTE: {"url": ""}}, +} + +from .azure import Azure, azure, azure_remote # noqa: F401 +from .hdfs import HDFS, hdfs, hdfs_remote # noqa: F401 +from .http import HTTP, http, http_remote, http_server # noqa: F401 +from .local import Local, local_cloud, local_remote # noqa: F401 +from .oss import OSS, TEST_OSS_REPO_BUCKET, oss, oss_remote # noqa: F401 +from .s3 import S3, TEST_AWS_REPO_BUCKET, S3Mocked, s3, s3_remote # noqa: F401 + +from .gdrive import ( # noqa: F401; noqa: F401 + TEST_GDRIVE_REPO_BUCKET, + GDrive, + gdrive, + gdrive_remote, +) +from .gs import ( # noqa: F401; noqa: F401 + GCP, + TEST_GCP_CREDS_FILE, + TEST_GCP_REPO_BUCKET, + gs, + gs_remote, +) +from .ssh import ( # noqa: F401; noqa: F401 + SSH, + SSHMocked, + ssh, + ssh_connection, + ssh_remote, + ssh_server, +) diff --git a/tests/remotes/azure.py b/tests/remotes/azure.py new file mode 100644 index 0000000000..2f15ca79a6 --- /dev/null +++ b/tests/remotes/azure.py @@ -0,0 +1,39 @@ +import os +import uuid + +import pytest + +from dvc.utils import env2bool + +from .base import Base + + +class Azure(Base): + @staticmethod + def should_test(): + do_test = env2bool("DVC_TEST_AZURE", undefined=None) + if do_test is not None: + return do_test + + return os.getenv("AZURE_STORAGE_CONTAINER_NAME") and os.getenv( + "AZURE_STORAGE_CONNECTION_STRING" + ) + + @staticmethod + def get_url(): + container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME") + assert container_name is not None + return "azure://{}/{}".format(container_name, str(uuid.uuid4())) + + +@pytest.fixture +def azure(): + if not Azure.should_test(): + pytest.skip("no azure running") + yield Azure() + + +@pytest.fixture +def 
azure_remote(tmp_dir, dvc, azure): + tmp_dir.add_remote(config=azure.config) + yield azure diff --git a/tests/remotes/base.py b/tests/remotes/base.py new file mode 100644 index 0000000000..ae8268f308 --- /dev/null +++ b/tests/remotes/base.py @@ -0,0 +1,23 @@ +from funcy import cached_property + + +class Base: + @staticmethod + def should_test(): + return True + + @staticmethod + def get_storagepath(): + raise NotImplementedError + + @staticmethod + def get_url(): + raise NotImplementedError + + @cached_property + def url(self): + return self.get_url() + + @cached_property + def config(self): + return {"url": self.url} diff --git a/tests/remotes/gdrive.py b/tests/remotes/gdrive.py new file mode 100644 index 0000000000..fb36b5fc79 --- /dev/null +++ b/tests/remotes/gdrive.py @@ -0,0 +1,55 @@ +import os +import uuid + +import pytest +from funcy import cached_property + +from dvc.remote.gdrive import GDriveRemoteTree + +from .base import Base + +TEST_GDRIVE_REPO_BUCKET = "root" + + +class GDrive(Base): + @staticmethod + def should_test(): + return os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) is not None + + @cached_property + def config(self): + return { + "url": self.url, + "gdrive_service_account_email": "test", + "gdrive_service_account_p12_file_path": "test.p12", + "gdrive_use_service_account": True, + } + + def __init__(self, dvc): + tree = GDriveRemoteTree(dvc, self.config) + tree._gdrive_create_dir("root", tree.path_info.path) + + @staticmethod + def get_storagepath(): + return TEST_GDRIVE_REPO_BUCKET + "/" + str(uuid.uuid4()) + + @staticmethod + def get_url(): + # NOTE: `get_url` should always return new random url + return "gdrive://" + GDrive.get_storagepath() + + +@pytest.fixture +def gdrive(make_tmp_dir): + if not GDrive.should_test(): + pytest.skip("no gdrive") + + # NOTE: temporary workaround + tmp_dir = make_tmp_dir("gdrive", dvc=True) + return GDrive(tmp_dir.dvc) + + +@pytest.fixture +def gdrive_remote(tmp_dir, dvc, gdrive): + tmp_dir.add_remote(config=gdrive.config) + return gdrive diff --git a/tests/remotes/gs.py b/tests/remotes/gs.py new file mode 100644 index 0000000000..dc5bcb42d5 --- /dev/null +++ b/tests/remotes/gs.py @@ -0,0 +1,90 @@ +import os +import uuid +from contextlib import contextmanager + +import pytest +from funcy import cached_property + +from dvc.remote.base import Remote +from dvc.remote.gs import GSRemoteTree +from dvc.utils import env2bool + +from .base import Base + +TEST_GCP_REPO_BUCKET = os.environ.get("DVC_TEST_GCP_REPO_BUCKET", "dvc-test") + +TEST_GCP_CREDS_FILE = os.path.abspath( + os.environ.get( + "GOOGLE_APPLICATION_CREDENTIALS", + os.path.join("scripts", "ci", "gcp-creds.json"), + ) +) +# Ensure that absolute path is used +os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE + + +class GCP(Base): + @staticmethod + def should_test(): + from subprocess import CalledProcessError, check_output + + do_test = env2bool("DVC_TEST_GCP", undefined=None) + if do_test is not None: + return do_test + + if not os.path.exists(TEST_GCP_CREDS_FILE): + return False + + try: + check_output( + [ + "gcloud", + "auth", + "activate-service-account", + "--key-file", + TEST_GCP_CREDS_FILE, + ] + ) + except (CalledProcessError, OSError): + return False + return True + + @cached_property + def config(self): + return { + "url": self.url, + "credentialpath": TEST_GCP_CREDS_FILE, + } + + @staticmethod + def get_storagepath(): + return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4()) + + @staticmethod + def get_url(): + return "gs://" + GCP.get_storagepath() 
+ + @classmethod + @contextmanager + def remote(cls, repo): + yield Remote(GSRemoteTree(repo, {"url": cls.get_url()})) + + @staticmethod + def put_objects(remote, objects): + client = remote.tree.gs + bucket = client.get_bucket(remote.path_info.bucket) + for key, body in objects.items(): + bucket.blob((remote.path_info / key).path).upload_from_string(body) + + +@pytest.fixture +def gs(): + if not GCP.should_test(): + pytest.skip("no gs") + yield GCP() + + +@pytest.fixture +def gs_remote(tmp_dir, dvc, gs): + tmp_dir.add_remote(config=gs.config) + yield gs diff --git a/tests/remotes/hdfs.py b/tests/remotes/hdfs.py new file mode 100644 index 0000000000..78f1e55017 --- /dev/null +++ b/tests/remotes/hdfs.py @@ -0,0 +1,55 @@ +import getpass +import os +import platform +from subprocess import CalledProcessError, Popen, check_output + +import pytest + +from .base import Base +from .local import Local + + +class HDFS(Base): + @staticmethod + def should_test(): + if platform.system() != "Linux": + return False + + try: + check_output( + ["hadoop", "version"], + shell=True, + executable=os.getenv("SHELL"), + ) + except (CalledProcessError, OSError): + return False + + p = Popen( + "hadoop fs -ls hdfs://127.0.0.1/", + shell=True, + executable=os.getenv("SHELL"), + ) + p.communicate() + if p.returncode != 0: + return False + + return True + + @staticmethod + def get_url(): + return "hdfs://{}@127.0.0.1{}".format( + getpass.getuser(), Local.get_storagepath() + ) + + +@pytest.fixture +def hdfs(): + if not HDFS.should_test(): + pytest.skip("no hadoop running") + yield HDFS() + + +@pytest.fixture +def hdfs_remote(tmp_dir, dvc, hdfs): + tmp_dir.add_remote(config=hdfs.config) + yield hdfs diff --git a/tests/remotes/http.py b/tests/remotes/http.py new file mode 100644 index 0000000000..0545932ccc --- /dev/null +++ b/tests/remotes/http.py @@ -0,0 +1,31 @@ +import pytest + +from .base import Base + + +class HTTP(Base): + @staticmethod + def get_url(port): + return f"http://127.0.0.1:{port}" + + def __init__(self, server): + self.url = self.get_url(server.server_port) + + +@pytest.fixture +def http_server(tmp_dir): + from tests.utils.httpd import PushRequestHandler, StaticFileServer + + with StaticFileServer(handler_class=PushRequestHandler) as httpd: + yield httpd + + +@pytest.fixture +def http(http_server): + yield HTTP(http_server) + + +@pytest.fixture +def http_remote(tmp_dir, dvc, http): + tmp_dir.add_remote(config=http.config) + yield http diff --git a/tests/remotes/local.py b/tests/remotes/local.py new file mode 100644 index 0000000000..b2df8f0899 --- /dev/null +++ b/tests/remotes/local.py @@ -0,0 +1,26 @@ +import pytest + +from tests.basic_env import TestDvc + +from .base import Base + + +class Local(Base): + @staticmethod + def get_storagepath(): + return TestDvc.mkdtemp() + + @staticmethod + def get_url(): + return Local.get_storagepath() + + +@pytest.fixture +def local_cloud(): + yield Local() + + +@pytest.fixture +def local_remote(tmp_dir, dvc, local_cloud): + tmp_dir.add_remote(config=local_cloud.config) + yield local_cloud diff --git a/tests/remotes/oss.py b/tests/remotes/oss.py new file mode 100644 index 0000000000..7fe87ca026 --- /dev/null +++ b/tests/remotes/oss.py @@ -0,0 +1,45 @@ +import os +import uuid + +import pytest + +from dvc.utils import env2bool + +from .base import Base + +TEST_OSS_REPO_BUCKET = "dvc-test" + + +class OSS(Base): + @staticmethod + def should_test(): + do_test = env2bool("DVC_TEST_OSS", undefined=None) + if do_test is not None: + return do_test + + return ( + 
diff --git a/tests/remotes/oss.py b/tests/remotes/oss.py
new file mode 100644
index 0000000000..7fe87ca026
--- /dev/null
+++ b/tests/remotes/oss.py
@@ -0,0 +1,45 @@
+import os
+import uuid
+
+import pytest
+
+from dvc.utils import env2bool
+
+from .base import Base
+
+TEST_OSS_REPO_BUCKET = "dvc-test"
+
+
+class OSS(Base):
+    @staticmethod
+    def should_test():
+        do_test = env2bool("DVC_TEST_OSS", undefined=None)
+        if do_test is not None:
+            return do_test
+
+        return (
+            os.getenv("OSS_ENDPOINT")
+            and os.getenv("OSS_ACCESS_KEY_ID")
+            and os.getenv("OSS_ACCESS_KEY_SECRET")
+        )
+
+    @staticmethod
+    def get_storagepath():
+        return f"{TEST_OSS_REPO_BUCKET}/{uuid.uuid4()}"
+
+    @staticmethod
+    def get_url():
+        return f"oss://{OSS.get_storagepath()}"
+
+
+@pytest.fixture
+def oss():
+    if not OSS.should_test():
+        pytest.skip("no oss running")
+    yield OSS()
+
+
+@pytest.fixture
+def oss_remote(tmp_dir, dvc, oss):
+    tmp_dir.add_remote(config=oss.config)
+    yield oss
diff --git a/tests/remotes/s3.py b/tests/remotes/s3.py
new file mode 100644
index 0000000000..866cd6c37e
--- /dev/null
+++ b/tests/remotes/s3.py
@@ -0,0 +1,68 @@
+import os
+import uuid
+from contextlib import contextmanager
+
+import pytest
+from moto.s3 import mock_s3
+
+from dvc.remote.base import Remote
+from dvc.remote.s3 import S3RemoteTree
+from dvc.utils import env2bool
+
+from .base import Base
+
+TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-temp")
+
+
+class S3(Base):
+    @staticmethod
+    def should_test():
+        do_test = env2bool("DVC_TEST_AWS", undefined=None)
+        if do_test is not None:
+            return do_test
+
+        if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv(
+            "AWS_SECRET_ACCESS_KEY"
+        ):
+            return True
+
+        return False
+
+    @staticmethod
+    def get_storagepath():
+        return TEST_AWS_REPO_BUCKET + "/" + str(uuid.uuid4())
+
+    @staticmethod
+    def get_url():
+        return "s3://" + S3.get_storagepath()
+
+
+@pytest.fixture
+def s3():
+    if not S3.should_test():
+        pytest.skip("no s3")
+    yield S3()
+
+
+@pytest.fixture
+def s3_remote(tmp_dir, dvc, s3):
+    tmp_dir.add_remote(config=s3.config)
+    yield s3
+
+
+class S3Mocked(S3):
+    @classmethod
+    @contextmanager
+    def remote(cls, repo):
+        with mock_s3():
+            yield Remote(S3RemoteTree(repo, {"url": cls.get_url()}))
+
+    @staticmethod
+    def put_objects(remote, objects):
+        s3 = remote.tree.s3
+        bucket = remote.path_info.bucket
+        s3.create_bucket(Bucket=bucket)
+        for key, body in objects.items():
+            s3.put_object(
+                Bucket=bucket, Key=(remote.path_info / key).path, Body=body
+            )
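`S3Mocked` reuses the `S3` URLs but runs everything inside moto's `mock_s3`, so unit tests can seed objects without real credentials. Roughly how a test might drive it (sketch only; assumes a `dvc` repo fixture as elsewhere in the suite):

    def test_with_mocked_s3(dvc):
        with S3Mocked.remote(dvc) as remote:
            # put_objects creates the bucket, so no other setup is needed
            S3Mocked.put_objects(remote, {"foo": "foo content"})
            assert remote.tree.exists(remote.path_info / "foo")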
diff --git a/tests/remotes/ssh.py b/tests/remotes/ssh.py
new file mode 100644
index 0000000000..fc5da2966a
--- /dev/null
+++ b/tests/remotes/ssh.py
@@ -0,0 +1,115 @@
+import getpass
+import os
+from subprocess import CalledProcessError, check_output
+
+import pytest
+from funcy import cached_property
+
+from dvc.utils import env2bool
+
+from .base import Base
+from .local import Local
+
+TEST_SSH_USER = "user"
+TEST_SSH_KEY_PATH = os.path.join(
+    os.path.abspath(os.path.dirname(__file__)), f"{TEST_SSH_USER}.key"
+)
+
+
+class SSH:
+    @staticmethod
+    def should_test():
+        do_test = env2bool("DVC_TEST_SSH", undefined=None)
+        if do_test is not None:
+            return do_test
+
+        # FIXME: enable on windows
+        if os.name == "nt":
+            return False
+
+        try:
+            check_output(["ssh", "-o", "BatchMode=yes", "127.0.0.1", "ls"])
+        except (CalledProcessError, OSError):
+            return False
+
+        return True
+
+    @staticmethod
+    def get_url():
+        return "ssh://{}@127.0.0.1:22{}".format(
+            getpass.getuser(), Local.get_storagepath()
+        )
+
+
+class SSHMocked(Base):
+    @staticmethod
+    def get_url(user, port):
+        path = Local.get_storagepath()
+        if os.name == "nt":
+            # NOTE: on Windows, Local.get_storagepath() will return an
+            # ntpath that looks something like `C:\some\path`, which is not
+            # compatible with SFTP paths [1], so we need to convert it to
+            # a proper posixpath.
+            # To do that, we should construct a posixpath that would be
+            # relative to the server's root.
+            # In our case our ssh server is running with `c:/` as a root,
+            # and our URL format requires absolute paths, so the
+            # resulting path would look like `/some/path`.
+            #
+            # [1] https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6
+            drive, path = os.path.splitdrive(path)
+            assert drive.lower() == "c:"
+            path = path.replace("\\", "/")
+        url = f"ssh://{user}@127.0.0.1:{port}{path}"
+        return url
+
+    def __init__(self, server):
+        self.server = server
+
+    @cached_property
+    def url(self):
+        return self.get_url(TEST_SSH_USER, self.server.port)
+
+    @cached_property
+    def config(self):
+        return {
+            "url": self.url,
+            "keyfile": TEST_SSH_KEY_PATH,
+        }
+
+
+@pytest.fixture(scope="session", autouse=True)
+def ssh_server():
+    import mockssh
+
+    users = {TEST_SSH_USER: TEST_SSH_KEY_PATH}
+    with mockssh.Server(users) as s:
+        yield s
+
+
+@pytest.fixture
+def ssh_connection(ssh_server):
+    from dvc.remote.ssh.connection import SSHConnection
+
+    yield SSHConnection(
+        host=ssh_server.host,
+        port=ssh_server.port,
+        username=TEST_SSH_USER,
+        key_filename=TEST_SSH_KEY_PATH,
+    )
+
+
+@pytest.fixture
+def ssh(ssh_server, monkeypatch):
+    from dvc.remote.ssh import SSHRemoteTree
+
+    # NOTE: see http://github.com/iterative/dvc/pull/3501
+    monkeypatch.setattr(SSHRemoteTree, "CAN_TRAVERSE", False)
+
+    return SSHMocked(ssh_server)
+
+
+@pytest.fixture
+def ssh_remote(tmp_dir, dvc, ssh):
+    tmp_dir.add_remote(config=ssh.config)
+    yield ssh
diff --git a/tests/user.key b/tests/remotes/user.key
similarity index 100%
rename from tests/user.key
rename to tests/remotes/user.key
diff --git a/tests/user.key.pub b/tests/remotes/user.key.pub
similarity index 100%
rename from tests/user.key.pub
rename to tests/remotes/user.key.pub
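The drive-stripping logic in `SSHMocked.get_url` is easiest to see with a concrete value; `ntpath` mirrors what `os.path` does on Windows, so this illustration runs on any platform:

    >>> import ntpath
    >>> drive, path = ntpath.splitdrive(r"C:\some\path")
    >>> drive
    'C:'
    >>> path.replace("\\", "/")
    '/some/path'

which yields a URL like `ssh://user@127.0.0.1:<port>/some/path`.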
diff --git a/tests/unit/remote/ssh/test_connection.py b/tests/unit/remote/ssh/test_connection.py
index abaaa20516..10499a2316 100644
--- a/tests/unit/remote/ssh/test_connection.py
+++ b/tests/unit/remote/ssh/test_connection.py
@@ -12,39 +12,39 @@
 here = os.path.abspath(os.path.dirname(__file__))
 
 
-def test_isdir(ssh):
-    assert ssh.isdir(here)
-    assert not ssh.isdir(__file__)
+def test_isdir(ssh_connection):
+    assert ssh_connection.isdir(here)
+    assert not ssh_connection.isdir(__file__)
 
 
-def test_exists(ssh):
-    assert not ssh.exists("/path/to/non/existent/file")
-    assert ssh.exists(__file__)
+def test_exists(ssh_connection):
+    assert not ssh_connection.exists("/path/to/non/existent/file")
+    assert ssh_connection.exists(__file__)
 
 
-def test_isfile(ssh):
-    assert ssh.isfile(__file__)
-    assert not ssh.isfile(here)
+def test_isfile(ssh_connection):
+    assert ssh_connection.isfile(__file__)
+    assert not ssh_connection.isfile(here)
 
 
-def test_makedirs(tmp_path, ssh):
+def test_makedirs(tmp_path, ssh_connection):
     tmp = tmp_path.absolute().as_posix()
     path = posixpath.join(tmp, "dir", "subdir")
-    ssh.makedirs(path)
+    ssh_connection.makedirs(path)
     assert os.path.isdir(path)
 
 
-def test_remove_dir(tmp_path, ssh):
+def test_remove_dir(tmp_path, ssh_connection):
     dpath = tmp_path / "dir"
     dpath.mkdir()
     (dpath / "file").write_text("file")
     (dpath / "subdir").mkdir()
     (dpath / "subdir" / "subfile").write_text("subfile")
-    ssh.remove(dpath.absolute().as_posix())
+    ssh_connection.remove(dpath.absolute().as_posix())
     assert not dpath.exists()
 
 
-def test_walk(tmp_path, ssh):
+def test_walk(tmp_path, ssh_connection):
     root_path = tmp_path
     dir_path = root_path / "dir"
     subdir_path = dir_path / "subdir"
@@ -75,7 +75,9 @@ def test_walk(tmp_path, ssh):
     expected = {entry.absolute().as_posix() for entry in entries}
 
     paths = set()
-    for root, dirs, files in ssh.walk(root_path.absolute().as_posix()):
+    for root, dirs, files in ssh_connection.walk(
+        root_path.absolute().as_posix()
+    ):
         for entry in dirs + files:
             paths.add(posixpath.join(root, entry))
 
@@ -87,9 +89,9 @@
     not in ["xfs", "apfs", "btrfs"],
     reason="Reflinks only work in specified file systems",
 )
-def test_reflink(tmp_dir, ssh):
+def test_reflink(tmp_dir, ssh_connection):
     tmp_dir.gen("foo", "foo content")
-    ssh.reflink("foo", "link")
+    ssh_connection.reflink("foo", "link")
     assert filecmp.cmp("foo", "link")
     assert not System.is_symlink("link")
     assert not System.is_hardlink("link")
@@ -99,9 +101,9 @@
     platform.system() == "Windows",
     reason="sftp symlink is not supported on Windows",
 )
-def test_symlink(tmp_dir, ssh):
+def test_symlink(tmp_dir, ssh_connection):
     tmp_dir.gen("foo", "foo content")
-    ssh.symlink("foo", "link")
+    ssh_connection.symlink("foo", "link")
     assert System.is_symlink("link")
 
 
@@ -109,9 +111,9 @@
     platform.system() == "Windows",
     reason="hardlink is temporarily not supported on Windows",
 )
-def test_hardlink(tmp_dir, ssh):
+def test_hardlink(tmp_dir, ssh_connection):
     tmp_dir.gen("foo", "foo content")
-    ssh.hardlink("foo", "link")
+    ssh_connection.hardlink("foo", "link")
     assert System.is_hardlink("link")
 
 
@@ -119,14 +121,14 @@
     platform.system() == "Windows",
     reason="copy is temporarily not supported on Windows",
 )
-def test_copy(tmp_dir, ssh):
+def test_copy(tmp_dir, ssh_connection):
     tmp_dir.gen("foo", "foo content")
-    ssh.copy("foo", "link")
+    ssh_connection.copy("foo", "link")
     assert filecmp.cmp("foo", "link")
 
 
-def test_move(tmp_dir, ssh):
+def test_move(tmp_dir, ssh_connection):
     tmp_dir.gen("foo", "foo content")
-    ssh.move("foo", "copy")
+    ssh_connection.move("foo", "copy")
     assert os.path.exists("copy")
     assert not os.path.exists("foo")
diff --git a/tests/unit/remote/ssh/test_pool.py b/tests/unit/remote/ssh/test_pool.py
index c3fb431423..f3094eeb82 100644
--- a/tests/unit/remote/ssh/test_pool.py
+++ b/tests/unit/remote/ssh/test_pool.py
@@ -2,13 +2,19 @@
 
 from dvc.remote.pool import get_connection
 from dvc.remote.ssh.connection import SSHConnection
+from tests.remotes.ssh import TEST_SSH_KEY_PATH, TEST_SSH_USER
 
 
 def test_doesnt_swallow_errors(ssh_server):
     class MyError(Exception):
         pass
 
-    with pytest.raises(MyError), get_connection(
-        SSHConnection, **ssh_server.test_creds
-    ):
-        raise MyError
+    with pytest.raises(MyError):
+        with get_connection(
+            SSHConnection,
+            host=ssh_server.host,
+            port=ssh_server.port,
+            username=TEST_SSH_USER,
+            key_filename=TEST_SSH_KEY_PATH,
+        ):
+            raise MyError
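For reference, `get_connection` checks a connection out of dvc's pool and is itself a context manager; the explicit kwargs above replace the `test_creds` dict that the old `ssh_server` fixture used to carry. Standalone usage would look roughly like this (host, port, and key values are illustrative):

    from dvc.remote.pool import get_connection
    from dvc.remote.ssh.connection import SSHConnection

    with get_connection(
        SSHConnection,
        host="127.0.0.1",
        port=2222,
        username="user",
        key_filename="user.key",
    ) as conn:
        assert conn.exists("/")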
diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py
index 27a3fd92bc..958c527300 100644
--- a/tests/unit/remote/ssh/test_ssh.py
+++ b/tests/unit/remote/ssh/test_ssh.py
@@ -7,7 +7,6 @@
 
 from dvc.remote.ssh import SSHRemoteTree
 from dvc.system import System
-from tests.remotes import SSHMocked
 
 
 def test_url(dvc):
@@ -184,17 +183,8 @@ def test_ssh_gss_auth(mock_file, mock_exists, dvc, config, expected_gss_auth):
     assert tree.gss_auth == expected_gss_auth
 
 
-def test_hardlink_optimization(dvc, tmp_dir, ssh_server):
-    port = ssh_server.test_creds["port"]
-    user = ssh_server.test_creds["username"]
-
-    config = {
-        "url": SSHMocked.get_url(user, port),
-        "port": port,
-        "user": user,
-        "keyfile": ssh_server.test_creds["key_filename"],
-    }
-    tree = SSHRemoteTree(dvc, config)
+def test_hardlink_optimization(dvc, tmp_dir, ssh):
+    tree = SSHRemoteTree(dvc, ssh.config)
     from_info = tree.path_info / "empty"
     to_info = tree.path_info / "link"