Skip to content

Commit

Permalink
gdrive: add open (iterative#3916)
Browse files Browse the repository at this point in the history
* gdrive: add open

Fixes iterative#3408
Related iterative#2865
Fixes iterative#3897

* dependency: add gdrive

* test: api: open: gdrive

* rollback pydrive dependency

* Revert "dependency: add gdrive"

This reverts commit fd33326.

* tests: remotes: GDrive: fix get_url

* tests: api: fully test GDrive

* minor typo fix

* attempt gdrive tests fix

* fix gdrive credentials config

* replace GDrivePathNotFound => FileMissingError

* fix test_open_external[GDrive]

* add ensure_dir_scm

* minor exception fix
  • Loading branch information
casperdcl authored Jun 13, 2020
1 parent 2ba73da commit 4d140a9
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 65 deletions.
7 changes: 5 additions & 2 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,12 @@ def __init__(self, etag, cached_etag):


class FileMissingError(DvcException):
    """Raised when a path can be found neither locally nor on the remote.

    Args:
        path: the path that could not be found; kept on the instance as
            ``self.path``.
        hint: optional extra sentence appended to the error message after
            a ``". "`` separator. ``None`` (the default) adds nothing.
    """

    def __init__(self, path, hint=None):
        self.path = path
        # Only add the separator when there is actually a hint to show.
        hint = "" if hint is None else f". {hint}"
        super().__init__(
            f"Can't find '{path}' neither locally nor on remote{hint}"
        )


class DvcIgnoreInCollectedDirError(DvcException):
Expand Down
34 changes: 24 additions & 10 deletions dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,35 @@
import io
import logging
import os
import posixpath
import re
import threading
from collections import defaultdict
from contextlib import contextmanager
from urllib.parse import urlparse

from funcy import cached_property, retry, wrap_prop, wrap_with
from funcy.py3 import cat

from dvc.exceptions import DvcException
from dvc.exceptions import DvcException, FileMissingError
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.remote.base import BaseRemote, BaseRemoteTree
from dvc.scheme import Schemes
from dvc.utils import format_link, tmp_fname
from dvc.utils.stream import IterStream

logger = logging.getLogger(__name__)
FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"


class GDrivePathNotFound(DvcException):
    """Error for a path that could not be resolved on the GDrive remote.

    Args:
        path_info: the path that was looked up.
        hint: optional extra text appended to the message, or ``None``.
    """

    def __init__(self, path_info, hint):
        suffix = f" {hint}" if hint is not None else ""
        message = f"GDrive path '{path_info}' not found.{suffix}"
        super().__init__(message)


class GDriveAuthError(DvcException):
def __init__(self, cred_location):

if cred_location:
message = (
"GDrive remote auth failed with credentials in '{}'.\n"
"Backup first, remove of fix them, and run DVC again.\n"
"Backup first, remove or fix them, and run DVC again.\n"
"It should do auth again and refresh the credentials.\n\n"
"Details:".format(cred_location)
)
Expand Down Expand Up @@ -389,6 +386,23 @@ def _gdrive_download_file(
) as pbar:
gdrive_file.GetContentFile(to_file, callback=pbar.update_to)

@contextmanager
@_gdrive_retry
def open(self, path_info, mode="r", encoding=None):
    """Open a remote GDrive file for streaming reads.

    Args:
        path_info: location of the file on the remote.
        mode: one of ``"r"``, ``"rt"`` (text) or ``"rb"`` (binary).
        encoding: text encoding passed to ``io.TextIOWrapper``; ignored
            in binary mode.

    Yields:
        A readable stream: an ``IterStream`` in binary mode, otherwise an
        ``io.TextIOWrapper`` around it.

    Raises:
        ValueError: if ``mode`` is not a supported read mode.
    """
    # NOTE(review): this is a generator function, so the body only runs
    # when the context is entered -- confirm `_gdrive_retry` still covers
    # the remote calls below as intended.
    if mode not in {"r", "rt", "rb"}:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(f"unsupported mode '{mode}', expected r/rt/rb")

    item_id = self._get_item_id(path_info)
    # it does not create a file on the remote
    gdrive_file = self._drive.CreateFile({"id": item_id})
    fd = gdrive_file.GetContentIOBuffer()
    stream = IterStream(iter(fd))

    if mode != "rb":
        stream = io.TextIOWrapper(stream, encoding=encoding)

    try:
        yield stream
    finally:
        # Mark the stream closed once the caller leaves the `with` block.
        stream.close()

@_gdrive_retry
def gdrive_delete_file(self, item_id):
from pydrive2.files import ApiRequestError
Expand Down Expand Up @@ -502,12 +516,12 @@ def _get_item_id(self, path_info, create=False, use_cache=True, hint=None):
return min(item_ids)

assert not create
raise GDrivePathNotFound(path_info, hint)
raise FileMissingError(path_info, hint)

def exists(self, path_info):
    """Return True if ``path_info`` resolves to an item on the remote.

    The lookup is delegated to ``_get_item_id``; a ``FileMissingError``
    from it is treated as "does not exist" rather than propagated.
    """
    try:
        self._get_item_id(path_info)
    except FileMissingError:
        return False
    else:
        return True
Expand Down
46 changes: 2 additions & 44 deletions dvc/utils/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
from contextlib import contextmanager

from dvc.utils.stream import IterStream


@contextmanager
def open_url(url, mode="r", encoding=None):
Expand Down Expand Up @@ -61,47 +63,3 @@ def gen(response):
finally:
# Ensure connection is closed
it.close()


class IterStream(io.RawIOBase):
    """Wraps an iterator yielding bytes as a read-only file object.

    Empty chunks produced by the iterator are skipped instead of being
    handed to the reader, because a zero-length read is how a raw stream
    signals EOF and would otherwise truncate the data.
    """

    def __init__(self, iterator):
        self.iterator = iterator
        # Unconsumed tail of the most recently fetched chunk (falsy if none).
        self.leftover = None

    def readable(self):
        return True

    def _next_chunk(self):
        """Return the pending leftover or the next non-empty chunk.

        Returns ``b""`` only once the iterator is exhausted.
        """
        chunk = self.leftover
        while not chunk:
            try:
                chunk = next(self.iterator)
            except StopIteration:
                return b""
        return chunk

    # Python 3 requires only the .readinto() method, but it still uses other
    # ones under some circumstances and falls back if those are absent. Since
    # the iterator already constructs byte strings for us, .readinto() is not
    # the most optimal, so we provide .read1() too.

    def readinto(self, b):
        """Copy at most ``len(b)`` bytes into ``b``; return 0 at EOF."""
        n = len(b)  # We're supposed to return at most this much
        chunk = self._next_chunk()
        output, self.leftover = chunk[:n], chunk[n:]

        n_out = len(output)
        b[:n_out] = output
        return n_out

    readinto1 = readinto

    def read1(self, n=-1):
        """Return up to ``n`` bytes from one chunk; ``b""`` means EOF.

        With ``n <= 0`` the whole current chunk is returned.
        """
        chunk = self._next_chunk()

        # Return an arbitrary number of bytes
        if n <= 0:
            self.leftover = None
            return chunk

        output, self.leftover = chunk[:n], chunk[n:]
        return output
45 changes: 45 additions & 0 deletions dvc/utils/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import io


class IterStream(io.RawIOBase):
    """Wraps an iterator yielding bytes as a read-only file object.

    Empty chunks produced by the iterator are skipped instead of being
    handed to the reader, because a zero-length read is how a raw stream
    signals EOF and would otherwise truncate the data.
    """

    def __init__(self, iterator):
        self.iterator = iterator
        # Unconsumed tail of the most recently fetched chunk (falsy if none).
        self.leftover = None

    def readable(self):
        return True

    def _next_chunk(self):
        """Return the pending leftover or the next non-empty chunk.

        Returns ``b""`` only once the iterator is exhausted.
        """
        chunk = self.leftover
        while not chunk:
            try:
                chunk = next(self.iterator)
            except StopIteration:
                return b""
        return chunk

    # Python 3 requires only the .readinto() method, but it still uses other
    # ones under some circumstances and falls back if those are absent. Since
    # the iterator already constructs byte strings for us, .readinto() is not
    # the most optimal, so we provide .read1() too.

    def readinto(self, b):
        """Copy at most ``len(b)`` bytes into ``b``; return 0 at EOF."""
        n = len(b)  # We're supposed to return at most this much
        chunk = self._next_chunk()
        output, self.leftover = chunk[:n], chunk[n:]

        n_out = len(output)
        b[:n_out] = output
        return n_out

    readinto1 = readinto

    def read1(self, n=-1):
        """Return up to ``n`` bytes from one chunk; ``b""`` means EOF.

        With ``n <= 0`` the whole current chunk is returned.
        """
        chunk = self._next_chunk()

        # Return an arbitrary number of bytes
        if n <= 0:
            self.leftover = None
            return chunk

        output, self.leftover = chunk[:n], chunk[n:]
        return output
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(self):
# Extra dependencies for remote integrations

gs = ["google-cloud-storage==1.19.0"]
gdrive = ["pydrive2>=1.4.13"]
gdrive = ["pydrive2>=1.4.14"]
s3 = ["boto3>=1.9.201"]
azure = ["azure-storage-blob==2.1.0"]
oss = ["oss2==2.6.1"]
Expand Down
67 changes: 60 additions & 7 deletions tests/func/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@
from dvc.main import main
from dvc.path_info import URLInfo
from dvc.utils.fs import remove
from tests.remotes import GCP, HDFS, OSS, S3, SSH, Azure, Local

remote_params = [S3, GCP, Azure, OSS, SSH, HDFS]
from tests.remotes import (
GCP,
HDFS,
OSS,
S3,
SSH,
TEST_REMOTE,
Azure,
GDrive,
Local,
)

remote_params = [S3, GCP, Azure, GDrive, OSS, SSH, HDFS]
all_remote_params = [Local] + remote_params


Expand All @@ -25,9 +35,48 @@ def run_dvc(*argv):
assert main(argv) == 0


def ensure_dir(dvc, url):
    """Prepare a GDrive test remote; do nothing for any other URL scheme.

    For a ``gdrive://`` URL this creates the remote directory and then
    sets the service-account options on ``TEST_REMOTE`` through the
    ``dvc remote modify`` command.
    """
    if not url.startswith("gdrive://"):
        return

    GDrive.create_dir(dvc, url)

    # Applied in this order, one `remote modify` invocation per option.
    service_account_options = (
        ("gdrive_service_account_email", "test"),
        ("gdrive_service_account_p12_file_path", "test.p12"),
        ("gdrive_use_service_account", "True"),
    )
    for option, value in service_account_options:
        run_dvc("remote", "modify", TEST_REMOTE, option, value)


def ensure_dir_scm(dvc, url):
    """Prepare a GDrive test remote and commit the config change to SCM.

    Does nothing unless ``url`` starts with ``gdrive://``. Otherwise the
    remote directory is created, the service-account settings are written
    to the ``TEST_REMOTE`` section of the config, and the repo-level
    config file is added and committed.
    """
    if not url.startswith("gdrive://"):
        return

    GDrive.create_dir(dvc, url)

    service_account_settings = {
        "gdrive_service_account_email": "test",
        "gdrive_service_account_p12_file_path": "test.p12",
        "gdrive_use_service_account": True,
    }
    with dvc.config.edit() as conf:
        conf["remote"][TEST_REMOTE].update(**service_account_settings)

    repo_config_file = dvc.config.files["repo"]
    dvc.scm.add(repo_config_file)
    dvc.scm.commit(f"modify '{TEST_REMOTE}' remote")


@pytest.mark.parametrize("remote_url", remote_params, indirect=True)
def test_get_url(tmp_dir, dvc, remote_url):
run_dvc("remote", "add", "-d", "upstream", remote_url)
run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
tmp_dir.dvc_gen("foo", "foo")

expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8"
Expand Down Expand Up @@ -58,7 +107,8 @@ def test_get_url_requires_dvc(tmp_dir, scm):

@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
def test_open(remote_url, tmp_dir, dvc):
run_dvc("remote", "add", "-d", "upstream", remote_url)
run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
ensure_dir(dvc, remote_url)
tmp_dir.dvc_gen("foo", "foo-text")
run_dvc("push")

Expand All @@ -72,6 +122,7 @@ def test_open(remote_url, tmp_dir, dvc):
@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
def test_open_external(remote_url, erepo_dir, setup_remote):
setup_remote(erepo_dir.dvc, url=remote_url)
ensure_dir_scm(erepo_dir.dvc, remote_url)

with erepo_dir.chdir():
erepo_dir.dvc_gen("version", "master", commit="add version")
Expand All @@ -95,7 +146,8 @@ def test_open_external(remote_url, erepo_dir, setup_remote):

@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
def test_open_granular(remote_url, tmp_dir, dvc):
run_dvc("remote", "add", "-d", "upstream", remote_url)
run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
ensure_dir(dvc, remote_url)
tmp_dir.dvc_gen({"dir": {"foo": "foo-text"}})
run_dvc("push")

Expand All @@ -109,7 +161,8 @@ def test_open_granular(remote_url, tmp_dir, dvc):
@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
def test_missing(remote_url, tmp_dir, dvc):
tmp_dir.dvc_gen("foo", "foo")
run_dvc("remote", "add", "-d", "upstream", remote_url)
run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
ensure_dir(dvc, remote_url)

# Remove cache to make foo missing
remove(dvc.cache.local.cache_dir)
Expand Down
7 changes: 6 additions & 1 deletion tests/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-temp")
TEST_GCP_REPO_BUCKET = os.environ.get("DVC_TEST_GCP_REPO_BUCKET", "dvc-test")
TEST_GDRIVE_REPO_BUCKET = "root"
TEST_OSS_REPO_BUCKET = "dvc-test"

TEST_GCP_CREDS_FILE = os.path.abspath(
Expand Down Expand Up @@ -152,10 +153,14 @@ def create_dir(dvc, url):
remote = GDriveRemote(dvc, config)
remote.tree._gdrive_create_dir("root", remote.path_info.path)

@staticmethod
def get_storagepath():
    """Build a storage path of the form ``<test bucket>/<random uuid4>``."""
    unique_dir = str(uuid.uuid4())
    return "/".join((TEST_GDRIVE_REPO_BUCKET, unique_dir))

@staticmethod
def get_url():
    """Return a ``gdrive://`` URL built from ``GDrive.get_storagepath()``."""
    # NOTE: `get_url` should always return new random url
    return "gdrive://" + GDrive.get_storagepath()


class Azure:
Expand Down

0 comments on commit 4d140a9

Please sign in to comment.