Skip to content

Commit

Permalink
Fix repro command. Close iterative#31
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov committed Apr 19, 2017
1 parent cf3695d commit b5dd8c1
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 199 deletions.
2 changes: 1 addition & 1 deletion bin/dvc-data-import
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

PYTHONPATH=$DVC_HOME python $DVC_HOME/dvc/command/data_import.py $@
PYTHONPATH=$DVC_HOME python $DVC_HOME/dvc/command/import_bulk.py $@
9 changes: 8 additions & 1 deletion dvc/command/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, settings):

parser = argparse.ArgumentParser()
self.define_args(parser)

self._parsed_args, self._command_args = parser.parse_known_args(args=self.args)

@property
Expand Down Expand Up @@ -49,7 +50,7 @@ def define_args(self, parser):
pass

def set_skip_git_actions(self, parser):
parser.add_argument('--skip-git-actions', '-s', action='store_true',
parser.add_argument('--skip-git-actions', '-s', action='store_true', default=False,
help='Skip all git actions including reproducibility check and commits')
parser.add_argument('--no-lock', '-L', action='store_true', default=False,
help='Do not set DVC locker')
Expand All @@ -58,10 +59,16 @@ def set_skip_git_actions(self, parser):
def skip_git_actions(self):
return self.parsed_args.skip_git_actions

def set_git_action(self, value):
    """Programmatically enable (True) or disable (False) git actions.

    Stored inverted because the parsed flag is named skip_git_actions.
    """
    self.parsed_args.skip_git_actions = not value

@property
def is_locker(self):
    # True when this command should take the DVC inter-process lock,
    # i.e. unless the user passed --no-lock / -L.
    return not self.parsed_args.no_lock

def set_locker(self, value):
    """Programmatically enable (True) or disable (False) DVC locking.

    BUG FIX: the original stored *value* directly into no_lock, which
    inverted the meaning of the call: set_locker(False) — used by
    CmdImportBulk, which already holds the inter-process lock — would make
    the child command try to re-acquire that lock.  Negating the argument
    mirrors set_git_action above.
    """
    self.parsed_args.no_lock = not value

def commit_if_needed(self, message, error=False):
if error or self.skip_git_actions:
self.not_committed_changes_warning()
Expand Down
124 changes: 27 additions & 97 deletions dvc/command/import_bulk.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,22 @@
import os
from shutil import copyfile
import re
import fasteners
import requests

from dvc.command.base import CmdBase
from dvc.command.data_sync import sizeof_fmt
from dvc.command.import_file import CmdImportFile
from dvc.logger import Logger
from dvc.exceptions import DvcException
from dvc.runtime import Runtime
from dvc.state_file import StateFile


class DataImportError(DvcException):
def __init__(self, msg):
DvcException.__init__(self, 'Import error: {}'.format(msg))


class CmdDataImport(CmdBase):
class CmdImportBulk(CmdBase):
def __init__(self, settings):
super(CmdDataImport, self).__init__(settings)
super(CmdImportBulk, self).__init__(settings)

def define_args(self, parser):
self.set_skip_git_actions(parser)

parser.add_argument('input',
metavar='',
help='Input file',
help='Input files',
nargs='*')

self.add_string_arg(parser, 'output', 'Output file')
Expand All @@ -47,93 +37,33 @@ def run(self):
if not self.skip_git_actions and not self.git.is_ready_to_go():
return 1

output = self.parsed_args.output
for file in self.parsed_args.input:
self.import_file(file, output, self.parsed_args.is_reproducible)
cmd = CmdImportFile(self.settings)
cmd.set_git_action(not self.skip_git_actions)
cmd.set_locker(False)

message = 'DVC data import: {} {}'.format(' '.join(self.parsed_args.input), self.parsed_args.output)
return self.commit_if_needed(message)
output = self.parsed_args.output
for input in self.parsed_args.input:
if not os.path.isdir(input):
cmd.import_and_commit_if_needed(input, output, self.parsed_args.is_reproducible)
else:
input_dir = os.path.basename(input)
for root, dirs, files in os.walk(input):
for file in files:
filename = os.path.join(root, file)

rel = os.path.relpath(filename, input)
out = os.path.join(output, input_dir, rel)

out_dir = os.path.dirname(out)
if not os.path.exists(out_dir):
os.mkdir(out_dir)

cmd.import_and_commit_if_needed(filename, out, self.parsed_args.is_reproducible)
pass
finally:
if self.is_locker:
lock.release()
pass

def import_file(self, input, output, is_reproducible):
if not CmdDataImport.is_url(input):
if not os.path.exists(input):
raise DataImportError('Input file "{}" does not exist'.format(input))
if not os.path.isfile(input):
raise DataImportError('Input file "{}" has to be a regular file'.format(input))

if os.path.isdir(output):
output = os.path.join(output, os.path.basename(input))

data_item = self.settings.path_factory.data_item(output)

if os.path.exists(data_item.data.relative):
raise DataImportError('Output file "{}" already exists'.format(data_item.data.relative))
if not os.path.isdir(os.path.dirname(data_item.data.relative)):
raise DataImportError('Output file directory "{}" does not exists'.format(
os.path.dirname(data_item.data.relative)))

cache_dir = os.path.dirname(data_item.cache.relative)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

if CmdDataImport.is_url(input):
Logger.debug('Downloading file {} ...'.format(input))
self.download_file(input, data_item.cache.relative)
Logger.debug('Input file "{}" was downloaded to cache "{}"'.format(
input, data_item.cache.relative))
else:
copyfile(input, data_item.cache.relative)
Logger.debug('Input file "{}" was copied to cache "{}"'.format(
input, data_item.cache.relative))

data_item.create_symlink()
Logger.debug('Symlink from data file "{}" to the cache file "{}" was created'.
format(data_item.data.relative, data_item.cache.relative))

state_file = StateFile(data_item.state.relative,
self.git,
[],
[output],
[],
is_reproducible)
state_file.save()
Logger.debug('State file "{}" was created'.format(data_item.state.relative))
pass

URL_REGEX = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain...
r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)

@staticmethod
def is_url(url):
return CmdDataImport.URL_REGEX.match(url) is not None

@staticmethod
def download_file(from_url, to_file):
r = requests.get(from_url, stream=True)

chunk_size = 1024 * 100
downloaded = 0
last_reported = 0
report_bucket = 100*1024*1024
with open(to_file, 'wb') as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
downloaded += chunk_size
last_reported += chunk_size
if last_reported >= report_bucket:
last_reported = 0
Logger.debug('Downloaded {}'.format(sizeof_fmt(downloaded)))
f.write(chunk)
return

if __name__ == '__main__':
Runtime.run(CmdDataImport)
Runtime.run(CmdImportBulk)
140 changes: 140 additions & 0 deletions dvc/command/import_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
from shutil import copyfile
import re
import fasteners
import requests

from dvc.command.base import CmdBase
from dvc.command.data_sync import sizeof_fmt
from dvc.logger import Logger
from dvc.exceptions import DvcException
from dvc.runtime import Runtime
from dvc.state_file import StateFile
from dvc.system import System


class ImportFileError(DvcException):
    """Raised for any failure while importing a single file."""

    def __init__(self, msg):
        super(ImportFileError, self).__init__('Import file: {}'.format(msg))


class CmdImportFile(CmdBase):
    """Import a single file — a local path or an http/ftp URL — into the
    DVC workspace.

    The file is copied (or downloaded) into the cache directory, a symlink
    to the cached copy is created at the requested output path, and a state
    file is written so the import can later be reproduced.
    """

    def __init__(self, settings):
        super(CmdImportFile, self).__init__(settings)

    def define_args(self, parser):
        """Register CLI arguments: input, output and the reproducibility flag."""
        self.set_skip_git_actions(parser)

        self.add_string_arg(parser, 'input', 'Input file')
        self.add_string_arg(parser, 'output', 'Output file')

        # BUG FIX: originally action='store_false' with default=False, which
        # made passing -i a no-op (the flag just re-assigned the default).
        # With 'store_true' the flag actually marks the file reproducible.
        parser.add_argument('-i', '--is-reproducible', action='store_true', default=False,
                            help='Is data file reproducible')

    def run(self):
        """Command entry point.

        Takes the inter-process DVC lock unless locking was disabled, then
        imports parsed_args.input to parsed_args.output.  Returns 0 on
        success, 1 when the lock cannot be acquired or the git repository
        is not ready.
        """
        if self.is_locker:
            lock = fasteners.InterProcessLock(self.git.lock_file)
            # Short timeout: another DVC process may hold the lock; fail fast.
            gotten = lock.acquire(timeout=5)
            if not gotten:
                Logger.info('Cannot perform the cmd since DVC is busy and locked. Please retry the cmd later.')
                return 1

        try:
            return self.import_and_commit_if_needed(self.parsed_args.input,
                                                    self.parsed_args.output,
                                                    self.parsed_args.is_reproducible)
        finally:
            if self.is_locker:
                lock.release()

    def import_and_commit_if_needed(self, input, output, is_reproducible=True, check_if_ready=True):
        """Import *input* to *output*, then git-commit unless git actions are skipped.

        check_if_ready=False lets a caller that already verified the repo
        state (e.g. a bulk import) skip the is_ready_to_go() check.
        """
        if check_if_ready and not self.skip_git_actions and not self.git.is_ready_to_go():
            return 1

        self.import_file(input, output, is_reproducible)

        # BUG FIX: originally ' '.join(input) — *input* is a single path
        # string, so the join inserted a space between every character.
        message = 'DVC import file: {} {}'.format(input, output)
        return self.commit_if_needed(message)

    def import_file(self, input, output, is_reproducible=True):
        """Copy or download *input* into the cache, symlink it at *output*,
        and write the state file describing the import.

        Raises ImportFileError when the input is missing or not a regular
        file, or when the output already exists or its directory is absent.
        """
        if not CmdImportFile.is_url(input):
            if not os.path.exists(input):
                raise ImportFileError('Input file "{}" does not exist'.format(input))
            if not os.path.isfile(input):
                raise ImportFileError('Input file "{}" has to be a regular file'.format(input))

        if os.path.isdir(output):
            # Importing into a directory keeps the original file name.
            output = os.path.join(output, os.path.basename(input))

        data_item = self.settings.path_factory.data_item(output)

        if os.path.exists(data_item.data.relative):
            raise ImportFileError('Output file "{}" already exists'.format(data_item.data.relative))
        if not os.path.isdir(os.path.dirname(data_item.data.relative)):
            # BUG FIX: grammar in the user-facing message ("does not exists").
            raise ImportFileError('Output file directory "{}" does not exist'.format(
                os.path.dirname(data_item.data.relative)))

        cache_dir = os.path.dirname(data_item.cache.relative)
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

        if CmdImportFile.is_url(input):
            Logger.debug('Downloading file {} ...'.format(input))
            self.download_file(input, data_item.cache.relative)
            Logger.debug('Input file "{}" was downloaded to cache "{}"'.format(
                input, data_item.cache.relative))
        else:
            copyfile(input, data_item.cache.relative)
            Logger.debug('Input file "{}" was copied to cache "{}"'.format(
                input, data_item.cache.relative))

        Logger.debug('Creating symlink {} --> {}'.format(data_item.symlink_file, data_item.data.relative))
        System.symlink(data_item.symlink_file, data_item.data.relative)

        state_file = StateFile(StateFile.COMMAND_IMPORT_FILE,
                               data_item.state.relative,
                               self.settings,
                               argv=[input, output],
                               input_files=[],
                               output_files=[output],
                               is_reproducible=is_reproducible)
        state_file.save()
        Logger.debug('State file "{}" was created'.format(data_item.state.relative))

    # Matches http(s)/ftp(s) URLs: scheme, host (domain, localhost or IPv4),
    # optional port and path.
    URL_REGEX = re.compile(
        r'^(?:http|ftp)s?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # domain...
        r'localhost|'  # localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
        r'(?::\d+)?'  # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)

    @staticmethod
    def is_url(url):
        """Return True when *url* looks like an http/ftp URL (see URL_REGEX)."""
        return CmdImportFile.URL_REGEX.match(url) is not None

    @staticmethod
    def download_file(from_url, to_file):
        """Stream *from_url* into *to_file* in 100 KB chunks, logging
        progress roughly every 100 MB."""
        r = requests.get(from_url, stream=True)

        chunk_size = 1024 * 100
        downloaded = 0
        last_reported = 0
        report_bucket = 100 * 1024 * 1024
        with open(to_file, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    # BUG FIX: count the bytes actually received; the final
                    # chunk is usually shorter than chunk_size.
                    downloaded += len(chunk)
                    last_reported += len(chunk)
                    if last_reported >= report_bucket:
                        last_reported = 0
                        Logger.debug('Downloaded {}'.format(sizeof_fmt(downloaded)))
                    f.write(chunk)

if __name__ == '__main__':
    # BUG FIX: the original passed CmdDataImport, a name that is not defined
    # in this module (the class here is CmdImportFile), so running the file
    # as a script raised NameError.
    Runtime.run(CmdImportFile)
2 changes: 1 addition & 1 deletion dvc/command/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def remove_file(self, target):

def _remove_cloud_cache(self, data_item):
if not self.parsed_args.keep_in_cloud:
aws_key = self.cache_file_aws_key(data_item.cache.dvc)
aws_key = self.cache_file_key(data_item.cache.dvc)
self.remove_from_cloud(aws_key)

def _remove_state_file(self, data_item):
Expand Down
Loading

0 comments on commit b5dd8c1

Please sign in to comment.