Skip to content

Commit

Permalink
ARROW-16048: [Python] Avoid exposing null buffer address to the Python buffer protocol
Browse files Browse the repository at this point in the history

A 0-size buffer created in C++ may legitimately have a null address. However, when exporting that buffer through the Python buffer protocol, we must ensure that we pass a valid (non-null) pointer.

Also, ensure mutability of a buffer is preserved when pickling and unpickling.

Closes apache#12752 from emkornfield/fix_bug

Lead-authored-by: Antoine Pitrou <[email protected]>
Co-authored-by: Micah Kornfield <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
pitrou and emkornfield committed Apr 21, 2022
1 parent 08ab8b0 commit 1dccb56
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
22 changes: 18 additions & 4 deletions python/pyarrow/io.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ from pyarrow.util import _is_path_like, _stringify_path
DEFAULT_BUFFER_SIZE = 2 ** 16


# To let us get a PyObject* and avoid Cython auto-ref-counting
cdef extern from "Python.h":
# To let us get a PyObject* and avoid Cython auto-ref-counting.
# The function is declared under the Cython-side name
# PyBytes_FromStringAndSizeNative; the quoted string is the C name that is
# actually called. `except NULL` makes a NULL return propagate as an error.
PyObject* PyBytes_FromStringAndSizeNative" PyBytes_FromStringAndSize"(
char *v, Py_ssize_t len) except NULL

# Workaround https://github.com/cython/cython/issues/4707
# Declared locally with a `bytearray` return type.
# NOTE(review): presumably this replaces a declaration shipped with Cython
# that is affected by the issue above -- confirm against the issue.
bytearray PyByteArray_FromStringAndSize(char *string, Py_ssize_t len)


def io_thread_count():
"""
Expand Down Expand Up @@ -1117,9 +1120,16 @@ cdef class Buffer(_Weakrefable):

def __reduce_ex__(self, protocol):
    """
    Support pickling of the buffer.

    Parameters
    ----------
    protocol : int
        The pickle protocol in use.

    Returns
    -------
    tuple
        ``(py_buffer, (bufobj,))``: the callable used to rebuild the
        buffer when unpickling, and its single argument.
    """
    if protocol >= 5:
        # Protocol 5 and above: hand the buffer over as a PickleBuffer.
        bufobj = builtin_pickle.PickleBuffer(self)
    elif self.buffer.get().is_mutable():
        # Need to pass a bytearray to recreate a mutable buffer when
        # unpickling.
        bufobj = PyByteArray_FromStringAndSize(
            <const char*>self.buffer.get().data(),
            self.buffer.get().size())
    else:
        # Immutable buffer: a bytes copy is used.
        bufobj = self.to_pybytes()
    # Single exit point so that every branch goes through py_buffer().
    return py_buffer, (bufobj,)

def to_pybytes(self):
"""
Expand All @@ -1138,10 +1148,14 @@ cdef class Buffer(_Weakrefable):
"buffer was not mutable")
buffer.readonly = 1
buffer.buf = <char *>self.buffer.get().data()
buffer.len = self.size
if buffer.buf == NULL:
# ARROW-16048: Ensure we don't export a NULL address.
assert buffer.len == 0
buffer.buf = cp.PyBytes_AS_STRING(b"")
buffer.format = 'b'
buffer.internal = NULL
buffer.itemsize = 1
buffer.len = self.size
buffer.ndim = 1
buffer.obj = self
buffer.shape = self.shape
Expand Down
35 changes: 28 additions & 7 deletions python/pyarrow/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def test_python_file_closing():
# Buffers


def check_buffer_pickling(buf):
    """Assert that *buf* survives a pickle roundtrip under every protocol.

    For each protocol from 0 up to ``pickle.HIGHEST_PROTOCOL``, the buffer
    is dumped and loaded again, and the restored object is compared with
    the original by length, memoryview contents, ``to_pybytes()`` output
    and ``is_mutable``.
    """
    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        payload = pickle.dumps(buf, protocol=proto)
        restored = pickle.loads(payload)
        assert len(restored) == len(buf)
        assert memoryview(restored) == memoryview(buf)
        assert restored.to_pybytes() == buf.to_pybytes()
        assert restored.is_mutable == buf.is_mutable


def test_buffer_bytes():
val = b'some data'

Expand All @@ -336,13 +346,22 @@ def test_buffer_bytes():
assert buf.is_cpu

result = buf.to_pybytes()

assert result == val

# Check that buffers survive a pickle roundtrip
result_buf = pickle.loads(pickle.dumps(buf))
result = result_buf.to_pybytes()
assert result == val
check_buffer_pickling(buf)


def test_buffer_null_data():
    # A zero-length foreign buffer wrapping address 0.
    empty_buf = pa.foreign_buffer(address=0, size=0)
    assert empty_buf.to_pybytes() == b""
    assert empty_buf.address == 0

    # ARROW-16048: the Python buffer protocol must not be handed a NULL
    # address, even though the Arrow buffer itself reports address 0.
    view = memoryview(empty_buf)
    assert view.tobytes() == b""
    rewrapped = pa.py_buffer(view)
    assert rewrapped.address != 0

    check_buffer_pickling(empty_buf)


def test_buffer_memoryview():
Expand All @@ -354,9 +373,10 @@ def test_buffer_memoryview():
assert buf.is_cpu

result = memoryview(buf)

assert result == val

check_buffer_pickling(buf)


def test_buffer_bytearray():
val = bytearray(b'some data')
Expand All @@ -367,9 +387,10 @@ def test_buffer_bytearray():
assert buf.is_cpu

result = bytearray(buf)

assert result == val

check_buffer_pickling(buf)


def test_buffer_invalid():
with pytest.raises(TypeError,
Expand Down

0 comments on commit 1dccb56

Please sign in to comment.