Skip to content

Commit

Permalink
Fixed seek offset size to 64bit. (#27125)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch/pytorch#26998.
Pull Request resolved: pytorch/pytorch#27125

Differential Revision: D17687154

Pulled By: ezyang

fbshipit-source-id: 6784f4fd799130ac72a25884f120a0ba96bd4f51
  • Loading branch information
peterjc123 authored and facebook-github-bot committed Oct 1, 2019
1 parent 48cd66c commit ec07d14
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
6 changes: 4 additions & 2 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5019,7 +5019,7 @@ def test_serialization_gzip(self):

def test_serialization_offset(self):
a = torch.randn(5, 5)
b = torch.randn(2, 2)
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
m = torch.nn.Conv2d(1, 1, (1, 3))
i, j = 41, 43
with tempfile.NamedTemporaryFile() as f:
Expand All @@ -5028,6 +5028,7 @@ def test_serialization_offset(self):
pickle.dump(j, f)
torch.save(b, f)
torch.save(m, f)
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
f.seek(0)
i_loaded = pickle.load(f)
a_loaded = torch.load(f)
Expand All @@ -5042,13 +5043,14 @@ def test_serialization_offset(self):

def test_serialization_offset_filelike(self):
a = torch.randn(5, 5)
b = torch.randn(2, 3)
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
i, j = 41, 43
with BytesIOContext() as f:
pickle.dump(i, f)
torch.save(a, f)
pickle.dump(j, f)
torch.save(b, f)
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
f.seek(0)
i_loaded = pickle.load(f)
a_loaded = torch.load(f)
Expand Down
16 changes: 11 additions & 5 deletions torch/csrc/generic/StorageMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
#include <cuda_runtime.h>
#endif

#ifdef _MSC_VER
#define LSEEK _lseeki64
#else
#define LSEEK lseek
#endif

static PyObject * THPStorage_(size)(THPStorage *self, PyObject *noargs)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -251,9 +257,9 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)

// file is backed by a fd
const int fd = PyObject_AsFileDescriptor(file);
const auto fd_original_pos = lseek(fd, 0, SEEK_CUR);
const auto fd_original_pos = LSEEK(fd, 0, SEEK_CUR);
if (offset != Py_None) {
lseek(fd, THPUtils_unpackLong(offset), SEEK_SET);
LSEEK(fd, THPUtils_unpackLong(offset), SEEK_SET);
}
THPUtils_assert(fd != -1, "_set_from_file couldn't retrieve a file "
"descriptor from given object");
Expand All @@ -265,9 +271,9 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
// the file descriptor is returned to original position and
// the file handle at python call-site needs updating to the
// advanced postion
const auto fd_current_pos = lseek(fd, 0, SEEK_CUR);
lseek(fd, fd_original_pos, SEEK_SET);
const auto seek_return = PyObject_CallMethod(file, "seek", "li", (long)fd_current_pos, 0);
const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR);
LSEEK(fd, fd_original_pos, SEEK_SET);
const auto seek_return = PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0);
if (seek_return == nullptr) {
return nullptr;
}
Expand Down

0 comments on commit ec07d14

Please sign in to comment.