Skip to content

Commit

Permalink
New option jobs for dvc import (iterative#4977)
Browse files Browse the repository at this point in the history
* fix iterative#4838

`jobs` option for `dvc import`

* Make the `jobs` option take effect

* Make the argument-parsing unit tests pass

* Adding a functional test.

* Pass `jobs` to `run` and remove some pylint warnings

* Update dvc/stage/__init__.py

Co-authored-by: Saugat Pachhai <[email protected]>

* remove unused param

* Fix test failure on Windows

Co-authored-by: Saugat Pachhai <[email protected]>
  • Loading branch information
karajan1001 and skshetry authored Dec 11, 2020
1 parent 7f52170 commit 96a97ec
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 11 deletions.
12 changes: 12 additions & 0 deletions dvc/command/imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run(self):
rev=self.args.rev,
no_exec=self.args.no_exec,
desc=self.args.desc,
jobs=self.args.jobs,
)
except DvcException:
logger.exception(
Expand Down Expand Up @@ -82,4 +83,15 @@ def add_parser(subparsers, parent_parser):
"This doesn't affect any DVC operations."
),
)
import_parser.add_argument(
"-j",
"--jobs",
type=int,
help=(
"Number of jobs to run simultaneously. "
"The default value is 4 * cpu_count(). "
"For SSH remotes, the default is 4. "
),
metavar="<number>",
)
import_parser.set_defaults(func=CmdImport)
4 changes: 2 additions & 2 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ def save(self):
def dumpd(self):
return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo}

def download(self, to):
def download(self, to, jobs=None):
cache = self.repo.cache.local

with self._make_repo(cache_dir=cache.cache_dir) as repo:
if self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev()

_, _, cache_infos = repo.fetch_external([self.def_path])
_, _, cache_infos = repo.fetch_external([self.def_path], jobs=jobs)

cache.checkout(to.path_info, cache_infos[0])

Expand Down
4 changes: 2 additions & 2 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def dumpd(self):
def verify_metric(self):
raise DvcException(f"verify metric is not supported for {self.scheme}")

def download(self, to):
self.tree.download(self.path_info, to.path_info)
def download(self, to, jobs=None):
self.tree.download(self.path_info, to.path_info, jobs=jobs)

def checkout(
self,
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/imp_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def imp_url(
frozen=True,
no_exec=False,
desc=None,
jobs=None,
):
from dvc.dvcfile import Dvcfile
from dvc.stage import Stage, create_stage, restore_meta
Expand Down Expand Up @@ -61,7 +62,7 @@ def imp_url(
if no_exec:
stage.ignore_outs()
else:
stage.run()
stage.run(jobs=jobs)

stage.frozen = frozen

Expand Down
3 changes: 2 additions & 1 deletion dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ def run(
self.remove_outs(ignore_remove=False, force=False)

if not self.frozen and self.is_import:
sync_import(self, dry, force)
jobs = kwargs.get("jobs", None)
sync_import(self, dry, force, jobs)
elif not self.frozen and self.cmd:
run_stage(self, dry, force, **kwargs)
else:
Expand Down
4 changes: 2 additions & 2 deletions dvc/stage/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def update_import(stage, rev=None):
stage.frozen = frozen


def sync_import(stage, dry=False, force=False):
def sync_import(stage, dry=False, force=False, jobs=None):
"""Synchronize import's outs to the workspace."""
logger.info(
"Importing '{dep}' -> '{out}'".format(
Expand All @@ -27,4 +27,4 @@ def sync_import(stage, dry=False, force=False):
stage.outs[0].checkout()
else:
stage.save_deps()
stage.deps[0].download(stage.outs[0])
stage.deps[0].download(stage.outs[0], jobs=jobs)
21 changes: 18 additions & 3 deletions dvc/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def download(
no_progress_bar=False,
file_mode=None,
dir_mode=None,
jobs=None,
):
if not hasattr(self, "_download"):
raise RemoteActionNotImplemented("download", self.scheme)
Expand All @@ -406,14 +407,27 @@ def download(

if self.isdir(from_info):
return self._download_dir(
from_info, to_info, name, no_progress_bar, file_mode, dir_mode
from_info,
to_info,
name,
no_progress_bar,
file_mode,
dir_mode,
jobs,
)
return self._download_file(
from_info, to_info, name, no_progress_bar, file_mode, dir_mode
)

def _download_dir(
self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode
self,
from_info,
to_info,
name,
no_progress_bar,
file_mode,
dir_mode,
jobs,
):
from_infos = list(self.walk_files(from_info))
to_infos = (
Expand All @@ -435,7 +449,8 @@ def _download_dir(
dir_mode=dir_mode,
)
)
with ThreadPoolExecutor(max_workers=self.jobs) as executor:
max_workers = jobs or self.jobs
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(download_files, from_info, to_info)
for from_info, to_info in zip(from_infos, to_infos)
Expand Down
22 changes: 22 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,25 @@ def test_import_with_no_exec(tmp_dir, dvc, erepo_dir):

dst = tmp_dir / "foo_imported"
assert not dst.exists()


def test_import_with_jobs(mocker, dvc, erepo_dir):
    """Check that ``jobs`` passed to ``dvc.imp`` reaches ``DataCloud.pull``."""
    from dvc.data_cloud import DataCloud

    # One directory holding four small files, generated inside erepo_dir.
    dir_contents = {f"file{i}": f"file{i}" for i in range(1, 5)}
    with erepo_dir.chdir():
        erepo_dir.dvc_gen({"dir1": dir_contents}, commit="init")

    pull_spy = mocker.spy(DataCloud, "pull")
    dvc.imp(os.fspath(erepo_dir), "dir1", jobs=3)

    # A recorded call converts to a tuple; index 1 is read as the
    # keyword arguments of the first ``pull`` invocation.
    first_call = tuple(pull_spy.call_args_list[0])
    pull_kwargs = first_call[1]
    assert pull_kwargs.get("jobs") == 3
4 changes: 4 additions & 0 deletions tests/unit/command/test_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_import(mocker):
"version",
"--desc",
"description",
"--jobs",
"3",
]
)
assert cli_args.func == CmdImport
Expand All @@ -33,6 +35,7 @@ def test_import(mocker):
rev="version",
no_exec=False,
desc="description",
jobs=3,
)


Expand Down Expand Up @@ -67,4 +70,5 @@ def test_import_no_exec(mocker):
rev="version",
no_exec=True,
desc="description",
jobs=None,
)

0 comments on commit 96a97ec

Please sign in to comment.