Skip to content

Commit

Permalink
dvc/utils/stream: make IterStream seekable and peekable (iterative#5084)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev authored Dec 28, 2020
1 parent 01a4473 commit c64b4ce
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
15 changes: 13 additions & 2 deletions dvc/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ class IterStream(io.RawIOBase):

def __init__(self, iterator): # pylint: disable=super-init-not-called
self.iterator = iterator
self.leftover = None
self.leftover = b""

def readable(self):
return True

def writable(self) -> bool:
return False

# Python 3 requires only .readinto() method, it still uses other ones
# under some circumstances and falls back if those are absent. Since
# iterator already constructs byte strings for us, .readinto() is not the
Expand Down Expand Up @@ -38,8 +41,16 @@ def read1(self, n=-1):

# Return an arbitrary number or bytes
if n <= 0:
self.leftover = None
self.leftover = b""
return chunk

output, self.leftover = chunk[:n], chunk[n:]
return output

def peek(self, n):
while len(self.leftover) < n:
try:
self.leftover += next(self.iterator)
except StopIteration:
break
return self.leftover[:n]
Binary file added tests/test.p12
Binary file not shown.
Binary file added tests/unit/remote/test.p12
Binary file not shown.
27 changes: 26 additions & 1 deletion tests/unit/utils/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from dvc.utils.http import open_url


def read_nbytes(fd, num, mode="r"):
data = b"" if mode == "rb" else ""
while len(data) < num:
chunk = fd.read(num - len(data))
if not chunk:
break
data += chunk
return data


def test_open_url(tmp_path, monkeypatch, http):
# Simulate bad connection
original_iter_content = requests.Response.iter_content
Expand All @@ -27,5 +37,20 @@ def bad_iter_content(self, *args, **kwargs):

with open_url((http / "sample.txt").url) as fd:
# Test various .read() variants
assert fd.read(len(text)) == text
assert read_nbytes(fd, len(text), mode="r") == text
assert fd.read() == text
assert fd.read() == ""


def test_open_url_peek_rb(tmp_path, monkeypatch, http):
# Goes over seek feature in 'rb' mode
text = "0123456789" * (io.DEFAULT_BUFFER_SIZE // 10 + 1)
http.gen("sample.txt", text * 2)

with open_url((http / "sample.txt").url, mode="rb") as fd:
text = text.encode("utf8")
assert fd.peek(len(text)) == text
assert read_nbytes(fd, len(text), mode="rb") == text
assert fd.peek(len(text)) == text
assert fd.read() == text
assert fd.read() == b""

0 comments on commit c64b4ce

Please sign in to comment.