From b5dd8c1d717000b54523ba2e06610e61ee81c0f6 Mon Sep 17 00:00:00 2001 From: Dmitry Petrov Date: Wed, 19 Apr 2017 03:07:04 -0700 Subject: [PATCH] Fix repro command. Close #31 --- bin/dvc-data-import | 2 +- dvc/command/base.py | 9 ++- dvc/command/import_bulk.py | 124 +++++++------------------------- dvc/command/import_file.py | 140 ++++++++++++++++++++++++++++++++++++ dvc/command/remove.py | 2 +- dvc/command/repro.py | 77 ++++++++++++++------ dvc/command/run.py | 34 ++++++--- dvc/git_wrapper.py | 7 +- dvc/path/data_item.py | 7 +- dvc/settings.py | 3 + dvc/state_file.py | 121 +++++++++++++++++-------------- dvc2.py | 14 ++-- tests/test_data_file_obj.py | 2 +- 13 files changed, 343 insertions(+), 199 deletions(-) create mode 100644 dvc/command/import_file.py diff --git a/bin/dvc-data-import b/bin/dvc-data-import index f1fff6a8cc..744cff498d 100755 --- a/bin/dvc-data-import +++ b/bin/dvc-data-import @@ -1,3 +1,3 @@ #!/bin/bash -PYTHONPATH=$DVC_HOME python $DVC_HOME/dvc/command/data_import.py $@ +PYTHONPATH=$DVC_HOME python $DVC_HOME/dvc/command/import_bulk.py $@ diff --git a/dvc/command/base.py b/dvc/command/base.py index 6aa6b8dfad..d8c78316fe 100644 --- a/dvc/command/base.py +++ b/dvc/command/base.py @@ -12,6 +12,7 @@ def __init__(self, settings): parser = argparse.ArgumentParser() self.define_args(parser) + self._parsed_args, self._command_args = parser.parse_known_args(args=self.args) @property @@ -49,7 +50,7 @@ def define_args(self, parser): pass def set_skip_git_actions(self, parser): - parser.add_argument('--skip-git-actions', '-s', action='store_true', + parser.add_argument('--skip-git-actions', '-s', action='store_true', default=False, help='Skip all git actions including reproducibility check and commits') parser.add_argument('--no-lock', '-L', action='store_true', default=False, help='Do not set DVC locker') @@ -58,10 +59,16 @@ def set_skip_git_actions(self, parser): def skip_git_actions(self): return self.parsed_args.skip_git_actions + def set_git_action(self, value): + self.parsed_args.skip_git_actions = not value + @property def is_locker(self): return not self.parsed_args.no_lock + def set_locker(self, value): + self.parsed_args.no_lock = value + def commit_if_needed(self, message, error=False): if error or self.skip_git_actions: self.not_committed_changes_warning() diff --git a/dvc/command/import_bulk.py b/dvc/command/import_bulk.py index e2ee61ea86..3c620e7a65 100644 --- a/dvc/command/import_bulk.py +++ b/dvc/command/import_bulk.py @@ -1,32 +1,22 @@ import os -from shutil import copyfile -import re import fasteners -import requests from dvc.command.base import CmdBase -from dvc.command.data_sync import sizeof_fmt +from dvc.command.import_file import CmdImportFile from dvc.logger import Logger -from dvc.exceptions import DvcException from dvc.runtime import Runtime -from dvc.state_file import StateFile -class DataImportError(DvcException): - def __init__(self, msg): - DvcException.__init__(self, 'Import error: {}'.format(msg)) - - -class CmdDataImport(CmdBase): +class CmdImportBulk(CmdBase): def __init__(self, settings): - super(CmdDataImport, self).__init__(settings) + super(CmdImportBulk, self).__init__(settings) def define_args(self, parser): self.set_skip_git_actions(parser) parser.add_argument('input', metavar='', - help='Input file', + help='Input files', nargs='*') self.add_string_arg(parser, 'output', 'Output file') @@ -47,93 +37,33 @@ def run(self): if not self.skip_git_actions and not self.git.is_ready_to_go(): return 1 - output = self.parsed_args.output - for file in self.parsed_args.input: - self.import_file(file, output, self.parsed_args.is_reproducible) + cmd = CmdImportFile(self.settings) + cmd.set_git_action(not self.skip_git_actions) + cmd.set_locker(False) - message = 'DVC data import: {} {}'.format(' '.join(self.parsed_args.input), self.parsed_args.output) - return self.commit_if_needed(message) + output = self.parsed_args.output + for input in self.parsed_args.input: + if not os.path.isdir(input): + cmd.import_and_commit_if_needed(input, output, self.parsed_args.is_reproducible) + else: + input_dir = os.path.basename(input) + for root, dirs, files in os.walk(input): + for file in files: + filename = os.path.join(root, file) + + rel = os.path.relpath(filename, input) + out = os.path.join(output, input_dir, rel) + + out_dir = os.path.dirname(out) + if not os.path.exists(out_dir): + os.mkdir(out_dir) + + cmd.import_and_commit_if_needed(filename, out, self.parsed_args.is_reproducible) + pass finally: if self.is_locker: lock.release() pass - def import_file(self, input, output, is_reproducible): - if not CmdDataImport.is_url(input): - if not os.path.exists(input): - raise DataImportError('Input file "{}" does not exist'.format(input)) - if not os.path.isfile(input): - raise DataImportError('Input file "{}" has to be a regular file'.format(input)) - - if os.path.isdir(output): - output = os.path.join(output, os.path.basename(input)) - - data_item = self.settings.path_factory.data_item(output) - - if os.path.exists(data_item.data.relative): - raise DataImportError('Output file "{}" already exists'.format(data_item.data.relative)) - if not os.path.isdir(os.path.dirname(data_item.data.relative)): - raise DataImportError('Output file directory "{}" does not exists'.format( - os.path.dirname(data_item.data.relative))) - - cache_dir = os.path.dirname(data_item.cache.relative) - if not os.path.exists(cache_dir): - os.makedirs(cache_dir) - - if CmdDataImport.is_url(input): - Logger.debug('Downloading file {} ...'.format(input)) - self.download_file(input, data_item.cache.relative) - Logger.debug('Input file "{}" was downloaded to cache "{}"'.format( - input, data_item.cache.relative)) - else: - copyfile(input, data_item.cache.relative) - Logger.debug('Input file "{}" was copied to cache "{}"'.format( - input, data_item.cache.relative)) - - data_item.create_symlink() - Logger.debug('Symlink from data file "{}" to the cache file "{}" was created'. - format(data_item.data.relative, data_item.cache.relative)) - - state_file = StateFile(data_item.state.relative, - self.git, - [], - [output], - [], - is_reproducible) - state_file.save() - Logger.debug('State file "{}" was created'.format(data_item.state.relative)) - pass - - URL_REGEX = re.compile( - r'^(?:http|ftp)s?://' # http:// or https:// - r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... - r'localhost|' # localhost... - r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip - r'(?::\d+)?' # optional port - r'(?:/?|[/?]\S+)$', re.IGNORECASE) - - @staticmethod - def is_url(url): - return CmdDataImport.URL_REGEX.match(url) is not None - - @staticmethod - def download_file(from_url, to_file): - r = requests.get(from_url, stream=True) - - chunk_size = 1024 * 100 - downloaded = 0 - last_reported = 0 - report_bucket = 100*1024*1024 - with open(to_file, 'wb') as f: - for chunk in r.iter_content(chunk_size=chunk_size): - if chunk: # filter out keep-alive new chunks - downloaded += chunk_size - last_reported += chunk_size - if last_reported >= report_bucket: - last_reported = 0 - Logger.debug('Downloaded {}'.format(sizeof_fmt(downloaded))) - f.write(chunk) - return - if __name__ == '__main__': - Runtime.run(CmdDataImport) + Runtime.run(CmdImportBulk) diff --git a/dvc/command/import_file.py b/dvc/command/import_file.py new file mode 100644 index 0000000000..7cc2e447d7 --- /dev/null +++ b/dvc/command/import_file.py @@ -0,0 +1,140 @@ +import os +from shutil import copyfile +import re +import fasteners +import requests + +from dvc.command.base import CmdBase +from dvc.command.data_sync import sizeof_fmt +from dvc.logger import Logger +from dvc.exceptions import DvcException +from dvc.runtime import Runtime +from dvc.state_file import StateFile +from dvc.system import System + + +class ImportFileError(DvcException): + def __init__(self, msg): + DvcException.__init__(self, 'Import file: {}'.format(msg)) + + +class CmdImportFile(CmdBase): + def __init__(self, settings): + super(CmdImportFile, self).__init__(settings) + + def define_args(self, parser): + self.set_skip_git_actions(parser) + + self.add_string_arg(parser, 'input', 'Input file') + self.add_string_arg(parser, 'output', 'Output file') + + parser.add_argument('-i', '--is-reproducible', action='store_false', default=False, + help='Is data file reproducible') + pass + + def run(self): + if self.is_locker: + lock = fasteners.InterProcessLock(self.git.lock_file) + gotten = lock.acquire(timeout=5) + if not gotten: + Logger.info('Cannot perform the cmd since DVC is busy and locked. Please retry the cmd later.') + return 1 + + try: + return self.import_and_commit_if_needed(self.parsed_args.input, + self.parsed_args.output, + self.parsed_args.is_reproducible) + finally: + if self.is_locker: + lock.release() + pass + + def import_and_commit_if_needed(self, input, output, is_reproducible=True, check_if_ready=True): + if check_if_ready and not self.skip_git_actions and not self.git.is_ready_to_go(): + return 1 + + self.import_file(input, output, is_reproducible) + + message = 'DVC import file: {} {}'.format(' '.join(input), output) + return self.commit_if_needed(message) + + def import_file(self, input, output, is_reproducible=True): + if not CmdImportFile.is_url(input): + if not os.path.exists(input): + raise ImportFileError('Input file "{}" does not exist'.format(input)) + if not os.path.isfile(input): + raise ImportFileError('Input file "{}" has to be a regular file'.format(input)) + + if os.path.isdir(output): + output = os.path.join(output, os.path.basename(input)) + + data_item = self.settings.path_factory.data_item(output) + + if os.path.exists(data_item.data.relative): + raise ImportFileError('Output file "{}" already exists'.format(data_item.data.relative)) + if not os.path.isdir(os.path.dirname(data_item.data.relative)): + raise ImportFileError('Output file directory "{}" does not exists'.format( + os.path.dirname(data_item.data.relative))) + + cache_dir = os.path.dirname(data_item.cache.relative) + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + if CmdImportFile.is_url(input): + Logger.debug('Downloading file {} ...'.format(input)) + self.download_file(input, data_item.cache.relative) + Logger.debug('Input file "{}" was downloaded to cache "{}"'.format( + input, data_item.cache.relative)) + else: + copyfile(input, data_item.cache.relative) + Logger.debug('Input file "{}" was copied to cache "{}"'.format( + input, data_item.cache.relative)) + + Logger.debug('Creating symlink {} --> {}'.format(data_item.symlink_file, data_item.data.relative)) + System.symlink(data_item.symlink_file, data_item.data.relative) + + # import_file_argv = [StateFile.DVC_PYTHON_FILE_NAME, StateFile.COMMAND_IMPORT_FILE, input, output] + state_file = StateFile(StateFile.COMMAND_IMPORT_FILE, + data_item.state.relative, + self.settings, + argv=[input, output], + input_files=[], + output_files=[output], + is_reproducible=is_reproducible) + state_file.save() + Logger.debug('State file "{}" was created'.format(data_item.state.relative)) + pass + + URL_REGEX = re.compile( + r'^(?:http|ftp)s?://' # http:// or https:// + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... + r'localhost|' # localhost... + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip + r'(?::\d+)?' # optional port + r'(?:/?|[/?]\S+)$', re.IGNORECASE) + + @staticmethod + def is_url(url): + return CmdImportFile.URL_REGEX.match(url) is not None + + @staticmethod + def download_file(from_url, to_file): + r = requests.get(from_url, stream=True) + + chunk_size = 1024 * 100 + downloaded = 0 + last_reported = 0 + report_bucket = 100*1024*1024 + with open(to_file, 'wb') as f: + for chunk in r.iter_content(chunk_size=chunk_size): + if chunk: # filter out keep-alive new chunks + downloaded += chunk_size + last_reported += chunk_size + if last_reported >= report_bucket: + last_reported = 0 + Logger.debug('Downloaded {}'.format(sizeof_fmt(downloaded))) + f.write(chunk) + return + +if __name__ == '__main__': + Runtime.run(CmdDataImport) diff --git a/dvc/command/remove.py b/dvc/command/remove.py index cf3f8d001e..4734c6bd1a 100644 --- a/dvc/command/remove.py +++ b/dvc/command/remove.py @@ -105,7 +105,7 @@ def remove_file(self, target): def _remove_cloud_cache(self, data_item): if not self.parsed_args.keep_in_cloud: - aws_key = self.cache_file_aws_key(data_item.cache.dvc) + aws_key = self.cache_file_key(data_item.cache.dvc) self.remove_from_cloud(aws_key) def _remove_state_file(self, data_item): diff --git a/dvc/command/repro.py b/dvc/command/repro.py index c6cbf5aa15..9d6e0996ee 100644 --- a/dvc/command/repro.py +++ b/dvc/command/repro.py @@ -1,6 +1,8 @@ import os import fasteners +import copy +from dvc.command.import_file import CmdImportFile from dvc.command.run import CmdRun from dvc.logger import Logger from dvc.exceptions import DvcException @@ -72,7 +74,8 @@ def repro_target(self, target, force): return 1 if self.repro_data_items(data_item_list, force): - return self.commit_if_needed('DVC repro: {}'.format(' '.join(target))) + # return self.commit_if_needed('DVC repro: {}'.format(' '.join(target))) + return 0 pass def repro_data_items(self, data_item_list, force): @@ -82,6 +85,11 @@ def repro_data_items(self, data_item_list, force): for data_item in data_item_list: try: target_commit = self.git.get_target_commit(data_item.data.relative) + if target_commit is None: + msg = 'Data item "{}" cannot be reproduced: file not found or commit not found' + Logger.warn(msg.format(data_item.data.relative)) + continue + repro_change = ReproChange(data_item, self, target_commit) if repro_change.reproduce(force): changed = True @@ -117,16 +125,12 @@ def __init__(self, data_item, cmd_obj, target_commit): self._target_commit = target_commit - argv = self.state.norm_argv - - if not argv: - raise ReproError('Error: parameter {} is nor defined in state file "{}"'. - format(StateFile.PARAM_NORM_ARGV, data_item.state.relative)) - if len(argv) < 2: + if not self.state.argv: + raise ReproError('Error: parameter {} is not defined in state file "{}"'. + format(StateFile.PARAM_ARGV, data_item.state.relative)) + if len(self.state.argv) < 1: raise ReproError('Error: reproducible cmd in state file "{}" is too short'. format(self.state.file)) - - self._repro_argv = argv pass @property @@ -137,32 +141,63 @@ def cmd_obj(self): def state(self): return self._state - def reproduce_data_file(self): + def reproduce_data_item(self): Logger.debug('Reproducing data item "{}". Removing the file...'.format( self._data_item.data.dvc)) os.remove(self._data_item.data.relative) - Logger.debug('Reproducing data item "{}". Re-runs cmd: {}'.format( - self._data_item.data.relative, ' '.join(self._repro_argv))) - - data_items_from_args = self.cmd_obj.data_items_from_args(self._repro_argv) - return self.cmd_obj.run_command(self._repro_argv, - data_items_from_args, - self.state.stdout, - self.state.stderr) + settings = copy.copy(self._cmd_obj.settings) + settings.set_args(self.state.argv) + + if self.state.is_import_file: + Logger.debug('Reproducing data item "{}". Re-import cmd: {}'.format( + self._data_item.data.relative, ' '.join(self.state.argv))) + + if len(self.state.argv) != 2: + msg = 'Data item "{}" cannot be re-imported because of arguments number {} is incorrect. Argv: {}' + raise ReproError(msg.format(self._data_item.data.relative, len(self.state.argv), self.state.argv)) + + input = self.state.argv[0] + output = self.state.argv[1] + + cmd = CmdImportFile(settings) + cmd.set_git_action(True) + cmd.set_locker(False) + + if cmd.import_and_commit_if_needed(input, output, is_reproducible=True, check_if_ready=False) != 0: + raise ReproError('Import command reproduction failed') + return True + else: + Logger.debug('Reproducing data item "{}". Re-run cmd: {}'.format( + self._data_item.data.relative, ' '.join(self.state.argv))) + + cmd = CmdRun(settings) + cmd.set_git_action(True) + cmd.set_locker(False) + + data_items_from_args = self.cmd_obj.data_items_from_args(self.state.argv) + if cmd.run_and_commit_if_needed(self.state.argv, + data_items_from_args, + self.state.stdout, + self.state.stderr, + self.state.shell, + check_if_ready=False) != 0: + raise ReproError('Run command reproduction failed') + return True def reproduce(self, force=False): + dependencies = self.dependencies Logger.debug('Reproduce data item {} with dependencies, force={}: {}'.format( self._data_item.data.dvc, force, - ', '.join([x.data.dvc for x in self.dependencies]))) + ', '.join([x.data.dvc for x in dependencies]))) if not force and not self.state.is_reproducible: Logger.debug('Data item "{}" is not reproducible'.format(self._data_item.data.relative)) return False were_input_files_changed = False - for data_item in self.dependencies: + for data_item in dependencies: change = ReproChange(data_item, self._cmd_obj, self._target_commit) if change.reproduce(force): were_input_files_changed = True @@ -178,7 +213,7 @@ def reproduce(self, force=False): self._data_item.data.relative)) return False - return self.reproduce_data_file() + return self.reproduce_data_item() @property def dependencies(self): diff --git a/dvc/command/run.py b/dvc/command/run.py index 0654d1b7c3..5a0cebda84 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -62,19 +62,28 @@ def run(self): return 1 try: - if not self.skip_git_actions and not self.git.is_ready_to_go(): - return 1 - - self.run_command(self.command_args, - self.data_items_from_args(self.command_args), - self.parsed_args.stdout, - self.parsed_args.stderr, - self.parsed_args.shell) - return self.commit_if_needed('DVC run: {}'.format(' '.join(self.args))) + return self.run_and_commit_if_needed(self.command_args, + self.data_items_from_args(self.command_args), + self.parsed_args.stdout, + self.parsed_args.stderr, + self.parsed_args.shell) finally: if self.is_locker: lock.release() + def run_and_commit_if_needed(self, command_args, command_args_data_items, + stdout, stderr, shell, check_if_ready=True): + if check_if_ready and not self.skip_git_actions and not self.git.is_ready_to_go(): + return 1 + + self.run_command(command_args, + command_args_data_items, + stdout, + stderr, + shell) + + return self.commit_if_needed('DVC run: {}'.format(' '.join(self.args))) + def run_command(self, cmd_args, data_items_from_args, stdout=None, stderr=None, shell=False): Logger.debug('Run command with args: {}. Data items from args: {}. stdout={}, stderr={}, shell={}'.format( ' '.join(cmd_args), @@ -105,14 +114,17 @@ def run_command(self, cmd_args, data_items_from_args, stdout=None, stderr=None, Logger.debug('Create state file "{}"'.format(data_item.state.relative)) - state_file = StateFile(data_item.state.relative, self.git, + state_file = StateFile(StateFile.COMMAND_RUN, + data_item.state.relative, + self.settings, input_files_dvc, output_files_dvc, code_dependencies_dvc, argv=cmd_args, is_reproducible=self.is_reproducible, stdout=self._stdout_to_dvc(stdout), - stderr=self._stdout_to_dvc(stderr)) + stderr=self._stdout_to_dvc(stderr), + shell=shell) state_file.save() result.append(state_file) diff --git a/dvc/git_wrapper.py b/dvc/git_wrapper.py index 005cbfa969..3e110186e0 100644 --- a/dvc/git_wrapper.py +++ b/dvc/git_wrapper.py @@ -181,8 +181,11 @@ def were_files_changed(self, code_dependencies, path_factory, target_commit): @staticmethod def get_target_commit(file): - commit = Executor.exec_cmd_only_success(['git', 'log', '-1', '--pretty=format:"%h"', file]) - return commit.strip('"') + try: + commit = Executor.exec_cmd_only_success(['git', 'log', '-1', '--pretty=format:"%h"', file]) + return commit.strip('"') + except ExecutorError: + return None def separate_dependency_files_and_dirs(self, code_dependencies): code_files = [] diff --git a/dvc/path/data_item.py b/dvc/path/data_item.py index 6eaf4420f4..cb68e40c6f 100644 --- a/dvc/path/data_item.py +++ b/dvc/path/data_item.py @@ -101,17 +101,14 @@ def state_dir_abs(self): return os.path.join(self._git.git_dir_abs, self._config.state_dir) @property - def _symlink_file(self): + def symlink_file(self): data_file_dir = os.path.dirname(self.data.relative) return os.path.relpath(self.cache.relative, data_file_dir) - def create_symlink(self): - System.symlink(self._symlink_file, self.data.relative) - def move_data_to_cache(self): cache_dir = os.path.dirname(self.cache.relative) if not os.path.isdir(cache_dir): os.makedirs(cache_dir) shutil.move(self.data.relative, self.cache.relative) - System.symlink(self._symlink_file, self.data.relative) + System.symlink(self.symlink_file, self.data.relative) diff --git a/dvc/settings.py b/dvc/settings.py index 7233b03791..5a911e0fd7 100644 --- a/dvc/settings.py +++ b/dvc/settings.py @@ -27,6 +27,9 @@ def __init__(self, args, git, config): def args(self): return self._args + def set_args(self, args): + self._args = args + @property def git(self): return self._git diff --git a/dvc/state_file.py b/dvc/state_file.py index de906b58d6..88fbd57179 100644 --- a/dvc/state_file.py +++ b/dvc/state_file.py @@ -4,7 +4,7 @@ import time from dvc.exceptions import DvcException -from dvc.logger import Logger +from dvc.path.data_item import NotInDataDirError from dvc.system import System @@ -17,10 +17,17 @@ class StateFile(object): MAGIC = 'DVC-State' VERSION = '0.1' + DVC_PYTHON_FILE_NAME = 'dvc2.py' + DVC_COMMAND = 'dvc' + + COMMAND_RUN = 'run' + COMMAND_IMPORT_FILE = 'import-file' + ACCEPTED_COMMANDS = {COMMAND_IMPORT_FILE, COMMAND_RUN} + + PARAM_COMMAND = 'Command' PARAM_TYPE = 'Type' PARAM_VERSION = 'Version' PARAM_ARGV = 'Argv' - PARAM_NORM_ARGV = 'NormArgv' PARAM_CWD = 'Cwd' PARAM_CREATED_AT = 'CreatedAt' PARAM_INPUT_FILES = 'InputFiles' @@ -29,28 +36,35 @@ class StateFile(object): PARAM_NOT_REPRODUCIBLE = 'NotReproducible' PARAM_STDOUT = "Stdout" PARAM_STDERR = "Stderr" - - def __init__(self, file, git, input_files, output_files, + PARAM_SHELL = "Shell" + + def __init__(self, + command, + file, + settings, + input_files, + output_files, code_dependencies=[], is_reproducible=True, argv=sys.argv, stdout=None, stderr=None, - norm_argv=None, created_at=time.strftime('%Y-%m-%d %H:%M:%S %z'), - cwd=None): + cwd=None, + shell=False): self.file = file - self.git = git + self.settings = settings self.input_files = input_files self.output_files = output_files self.is_reproducible = is_reproducible self.code_dependencies = code_dependencies + self.shell = shell - self.argv = argv - if norm_argv: - self.norm_argv = norm_argv - else: - self.norm_argv = self.normalized_args() + if command not in self.ACCEPTED_COMMANDS: + raise StateFileError('Args error: unknown command %s' % command) + self.command = command + + self._argv = argv self.stdout = stdout self.stderr = stderr @@ -63,12 +77,25 @@ def __init__(self, file, git, input_files, output_files, self.cwd = self.get_dvc_path() pass + @property + def is_import_file(self): + return self.command == self.COMMAND_IMPORT_FILE + + @property + def is_run(self): + return self.command == self.COMMAND_RUN + + @property + def argv(self): + return self._argv + @staticmethod def load(filename, git): with open(filename, 'r') as fd: data = json.load(fd) - return StateFile(filename, + return StateFile(data.get(StateFile.PARAM_COMMAND), + filename, git, data.get(StateFile.PARAM_INPUT_FILES, []), data.get(StateFile.PARAM_OUTPUT_FILES, []), @@ -77,23 +104,27 @@ def load(filename, git): data.get(StateFile.PARAM_ARGV), data.get(StateFile.PARAM_STDOUT), data.get(StateFile.PARAM_STDERR), - data.get(StateFile.PARAM_NORM_ARGV), data.get(StateFile.PARAM_CREATED_AT), - data.get(StateFile.PARAM_CWD)) + data.get(StateFile.PARAM_CWD), + data.get(StateFile.PARAM_SHELL, False)) def save(self): + # cmd, argv = self.process_args(self._argv) + argv = self._argv_paths_normalization(self._argv) + res = { + self.PARAM_COMMAND: self.command, self.PARAM_TYPE: self.MAGIC, self.PARAM_VERSION: self.VERSION, - self.PARAM_ARGV: self.process_args(self.argv, 'argv'), - self.PARAM_NORM_ARGV: self.process_args(self.norm_argv, 'normalized argv'), + self.PARAM_ARGV: argv, self.PARAM_CWD: self.cwd, self.PARAM_CREATED_AT: self.created_at, self.PARAM_INPUT_FILES: self.input_files, self.PARAM_OUTPUT_FILES: self.output_files, self.PARAM_CODE_DEPENDENCIES: self.code_dependencies, self.PARAM_STDOUT: self.stdout, - self.PARAM_STDERR: self.stderr + self.PARAM_STDERR: self.stderr, + self.PARAM_SHELL: self.shell } if not self.is_reproducible: @@ -107,50 +138,32 @@ def save(self): json.dump(res, fd, indent=2) pass - def process_args(self, argv, name='argv'): - was_changed = False + # def process_args(self, argv): + # if len(argv) >= 2 and argv[0].endswith(self.DVC_PYTHON_FILE_NAME): + # if argv[1] in self.ACCEPTED_COMMANDS: + # return argv[1], self._argv_paths_normalization(argv[2:]) + # else: + # msg = 'File generation error: command "{}" is not allowed. Argv={}' + # raise StateFileError(msg.format(argv[1], argv)) + # else: + # msg = 'File generation error: dvc python command "{}" format error. Argv={}' + # raise StateFileError(msg.format(self.DVC_PYTHON_FILE_NAME, argv)) + + def _argv_paths_normalization(self, argv): result = [] for arg in argv: - if arg.endswith('dvc2.py'): - result.append('dvc') - was_changed = True - else: + try: + data_item = self.settings.path_factory.data_item(arg) + result.append(data_item.data.dvc) + except NotInDataDirError: result.append(arg) - if was_changed: - Logger.debug('Save state file {}. Replace {} "{}" to "{}"'.format( - self.file, - name, - argv, - result - )) - - return result - - def normalized_args(self): - result = [] - - if len(self.argv) > 0: - cmd = self.argv[0] - pos = cmd.rfind(os.sep) - if pos >= 0: - cmd = cmd[pos+1:] - result.append(cmd) - - for arg in self.argv[1:]: - if os.path.isfile(arg): # CHANGE to data items - path = os.path.abspath(arg) - dvc_path = os.path.relpath(path, self.git.git_dir_abs) - result.append(dvc_path) - else: - result.append(arg) - return result def get_dvc_path(self): pwd = System.get_cwd() - if not pwd.startswith(self.git.git_dir_abs): + if not pwd.startswith(self.settings.git.git_dir_abs): raise StateFileError('the file cannot be created outside of a git repository') - return os.path.relpath(pwd, self.git.git_dir_abs) + return os.path.relpath(pwd, self.settings.git.git_dir_abs) diff --git a/dvc2.py b/dvc2.py index 1647af7e81..0258444399 100755 --- a/dvc2.py +++ b/dvc2.py @@ -1,5 +1,7 @@ from __future__ import print_function +from dvc.command.import_file import CmdImportFile + """ main entry point / argument parsing for dvc @@ -10,12 +12,12 @@ from dvc.runtime import Runtime from dvc.command.init import CmdInit -from dvc.command.data_import import CmdDataImport +from dvc.command.import_bulk import CmdImportBulk from dvc.command.remove import CmdDataRemove from dvc.command.run import CmdRun from dvc.command.repro import CmdRepro from dvc.command.data_sync import CmdDataSync -from dvc.command.data_import import CmdDataImport +from dvc.command.import_bulk import CmdImportBulk from dvc.command.test import CmdTest @@ -40,7 +42,7 @@ def print_usage(): print('\n'.join(usage)) if __name__ == '__main__': - cmds = ['init', 'run', 'sync', 'repro', 'data', 'data-sync', 'data-remove', 'data-import', 'cloud', \ + cmds = ['init', 'run', 'sync', 'repro', 'data', 'data-sync', 'data-remove', 'import', 'import-file', 'cloud', \ 'cloud', 'cloud-run', 'cloud-instance-create', 'cloud-instance-remove', 'cloud-instance-describe', \ 'test', 'test-aws', 'test-gcloud', 'test-cloud'] cmds_expand = {'data': ['sync', 'remove', 'import'], @@ -75,8 +77,10 @@ def print_usage(): Runtime.run(CmdRepro, args_start_loc=2) elif cmd == 'data-sync' or (cmd == 'data' and subcmd == 'sync'): Runtime.run(CmdDataSync, args_start_loc=argv_offset) - elif cmd == 'data-import' or (cmd == 'data' and subcmd == 'import'): - Runtime.run(CmdDataImport, args_start_loc=argv_offset) + elif cmd == 'import': + Runtime.run(CmdImportBulk, args_start_loc=argv_offset) + elif cmd == 'import-file' or (cmd == 'data' and subcmd == 'import'): + Runtime.run(CmdImportFile, args_start_loc=argv_offset) elif cmd == 'data-remove' or (cmd == 'data' and subcmd == 'remove'): Runtime.run(CmdDataRemove, args_start_loc=argv_offset) elif cmd == 'cloud-run' or (cmd == 'cloud' and subcmd == 'run'): diff --git a/tests/test_data_file_obj.py b/tests/test_data_file_obj.py index 32ef904ff1..8397fb4f58 100644 --- a/tests/test_data_file_obj.py +++ b/tests/test_data_file_obj.py @@ -134,7 +134,7 @@ def test_state_file(self): def test_symlink(self): expected = os.path.join('..', '..', '..', 'ca', 'dir1', 'd2', 'file.txt_eeeff8f') - self.assertEqual(self.data_path._symlink_file, expected) + self.assertEqual(self.data_path.symlink_file, expected) def test_data_dir(self): data_path = self.path_factory.data_item(self.data_dir)