Skip to content

Commit

Permalink
ENH: add options for disabling use of pickle in load/save
Browse files Browse the repository at this point in the history
  • Loading branch information
pv committed Apr 18, 2015
1 parent 0752872 commit a2bd3a7
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
20 changes: 16 additions & 4 deletions numpy/lib/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def _read_array_header(fp, version):

return d['shape'], d['fortran_order'], dtype

def write_array(fp, array, version=None, pickle_kwargs=None):
def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
"""
Write an array to an NPY file, including a header.
Expand All @@ -533,6 +533,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
version : (int, int) or None, optional
The version number of the format. None means use the oldest
supported version that is able to store the data. Default: None
allow_pickle : bool, optional
Whether to allow writing pickled data. Default: True
pickle_kwargs : dict, optional
Additional keyword arguments to pass to pickle.dump, excluding
'protocol'. These are only useful when pickling objects in object
Expand All @@ -541,7 +543,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
Raises
------
ValueError
If the array cannot be persisted.
If the array cannot be persisted. This includes the case of
allow_pickle=False and array being an object array.
Various other errors
If the array contains Python objects as part of its dtype, the
process of pickling them may raise various errors if the objects
Expand All @@ -563,6 +566,9 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
# We contain Python objects so we cannot write out the data
# directly. Instead, we will pickle it out with version 2 of the
# pickle protocol.
if not allow_pickle:
raise ValueError("Object arrays cannot be saved when "
"allow_pickle=False")
if pickle_kwargs is None:
pickle_kwargs = {}
pickle.dump(array, fp, protocol=2, **pickle_kwargs)
Expand All @@ -584,7 +590,7 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
fp.write(chunk.tobytes('C'))


def read_array(fp, pickle_kwargs=None):
def read_array(fp, allow_pickle=True, pickle_kwargs=None):
"""
Read an array from an NPY file.
Expand All @@ -593,6 +599,8 @@ def read_array(fp, pickle_kwargs=None):
fp : file_like object
If this is not a real file object, then this may take extra memory
and time.
allow_pickle : bool, optional
Whether to allow reading pickled data. Default: True
pickle_kwargs : dict
Additional keyword arguments to pass to pickle.load. These are only
useful when loading object arrays saved on Python 2 when using
Expand All @@ -606,7 +614,8 @@ def read_array(fp, pickle_kwargs=None):
Raises
------
ValueError
If the data is invalid.
If the data is invalid, or allow_pickle=False and the file contains
an object array.
"""
version = read_magic(fp)
Expand All @@ -620,6 +629,9 @@ def read_array(fp, pickle_kwargs=None):
# Now read the actual data.
if dtype.hasobject:
# The array contained Python objects. We need to unpickle the data.
if not allow_pickle:
raise ValueError("Object arrays cannot be loaded when "
"allow_pickle=False")
if pickle_kwargs is None:
pickle_kwargs = {}
try:
Expand Down
44 changes: 36 additions & 8 deletions numpy/lib/npyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class NpzFile(object):
f : BagObj instance
An object on which attribute can be performed as an alternative
to getitem access on the `NpzFile` instance itself.
allow_pickle : bool, optional
Allow loading pickled data. Default: True
pickle_kwargs : dict, optional
Additional keyword arguments to pass on to pickle.load.
These are only useful when loading object arrays saved on
Expand Down Expand Up @@ -199,12 +201,14 @@ class NpzFile(object):
"""

def __init__(self, fid, own_fid=False, pickle_kwargs=None):
def __init__(self, fid, own_fid=False, allow_pickle=True,
pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an
# optional component of the so-called standard library.
_zip = zipfile_factory(fid)
self._files = _zip.namelist()
self.files = []
self.allow_pickle = allow_pickle
self.pickle_kwargs = pickle_kwargs
for x in self._files:
if x.endswith('.npy'):
Expand Down Expand Up @@ -262,6 +266,7 @@ def __getitem__(self, key):
if magic == format.MAGIC_PREFIX:
bytes = self.zip.open(key)
return format.read_array(bytes,
allow_pickle=self.allow_pickle,
pickle_kwargs=self.pickle_kwargs)
else:
return self.zip.read(key)
Expand Down Expand Up @@ -295,7 +300,8 @@ def __contains__(self, key):
return self.files.__contains__(key)


def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
encoding='ASCII'):
"""
Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files.
Expand All @@ -312,6 +318,12 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
and sliced like any ndarray. Memory mapping is especially useful
for accessing small fragments of large files without reading the
entire file into memory.
allow_pickle : bool, optional
Allow loading pickled object arrays stored in npy files. Reasons for
disallowing pickles include security, as loading pickled data can
execute arbitrary code. If pickles are disallowed, loading object
arrays will fail.
Default: True
fix_imports : bool, optional
Only useful when loading Python 2 generated pickled files on Python 3,
which includes npy/npz files containing object arrays. If `fix_imports`
Expand All @@ -324,7 +336,6 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
'ASCII', and 'bytes' are not allowed, as they can corrupt numerical
data. Default: 'ASCII'
Returns
-------
result : array, tuple, dict, etc.
Expand All @@ -335,6 +346,8 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
------
IOError
If the input file does not exist or cannot be read.
ValueError
The file contains an object array, but allow_pickle=False given.
See Also
--------
Expand Down Expand Up @@ -430,15 +443,20 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
# Transfer file ownership to NpzFile
tmp = own_fid
own_fid = False
return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs)
return NpzFile(fid, own_fid=tmp, allow_pickle=allow_pickle,
pickle_kwargs=pickle_kwargs)
elif magic == format.MAGIC_PREFIX:
# .npy file
if mmap_mode:
return format.open_memmap(file, mode=mmap_mode)
else:
return format.read_array(fid, pickle_kwargs=pickle_kwargs)
return format.read_array(fid, allow_pickle=allow_pickle,
pickle_kwargs=pickle_kwargs)
else:
# Try a pickle
if not allow_pickle:
raise ValueError("allow_pickle=False, but file does not contain "
"non-pickled data")
try:
return pickle.load(fid, **pickle_kwargs)
except:
Expand All @@ -449,7 +467,7 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
fid.close()


def save(file, arr, fix_imports=True):
def save(file, arr, allow_pickle=True, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
Expand All @@ -460,6 +478,14 @@ def save(file, arr, fix_imports=True):
then the filename is unchanged. If file is a string, a ``.npy``
extension will be appended to the file name if it does not already
have one.
allow_pickle : bool, optional
Allow saving object arrays using Python pickles. Reasons for disallowing
pickles include security (loading pickled data can execute arbitrary
code) and portability (pickled objects may not be loadable on different
Python installations, for example if the stored objects require libraries
that are not available, and not all pickled data is compatible between
Python 2 and Python 3).
Default: True
fix_imports : bool, optional
Only useful in forcing objects in object arrays on Python 3 to be
pickled in a Python 2 compatible way. If `fix_imports` is True, pickle
Expand Down Expand Up @@ -509,7 +535,8 @@ def save(file, arr, fix_imports=True):

try:
arr = np.asanyarray(arr)
format.write_array(fid, arr, pickle_kwargs=pickle_kwargs)
format.write_array(fid, arr, allow_pickle=allow_pickle,
pickle_kwargs=pickle_kwargs)
finally:
if own_fid:
fid.close()
Expand Down Expand Up @@ -621,7 +648,7 @@ def savez_compressed(file, *args, **kwds):
_savez(file, args, kwds, True)


def _savez(file, args, kwds, compress, pickle_kwargs=None):
def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an optional
# component of the so-called standard library.
import zipfile
Expand Down Expand Up @@ -656,6 +683,7 @@ def _savez(file, args, kwds, compress, pickle_kwargs=None):
fid = open(tmpfile, 'wb')
try:
format.write_array(fid, np.asanyarray(val),
allow_pickle=allow_pickle,
pickle_kwargs=pickle_kwargs)
fid.close()
fid = None
Expand Down
16 changes: 16 additions & 0 deletions numpy/lib/tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,22 @@ def test_pickle_python2_python3():
encoding='latin1', fix_imports=False)


def test_pickle_disallow():
    """Check that ``allow_pickle=False`` rejects object arrays.

    Three paths are exercised:

    * loading a ``.npy`` file holding an object array,
    * accessing an object-array member of a ``.npz`` archive,
    * saving an object array to a ``.npy`` file.

    Each must raise ``ValueError`` instead of touching pickle.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')

    # Plain .npy file: the array is read inside np.load, so the call
    # itself raises.
    path = os.path.join(data_dir, 'py2-objarr.npy')
    assert_raises(ValueError, np.load, path,
                  allow_pickle=False, encoding='latin1')

    # .npz archive: np.load only returns the NpzFile wrapper; the member is
    # read on item access, so the error surfaces from __getitem__.
    # The archive is opened in a ``with`` block so its underlying file
    # handle is closed even though the access is expected to fail.
    path = os.path.join(data_dir, 'py2-objarr.npz')
    with np.load(path, allow_pickle=False, encoding='latin1') as f:
        assert_raises(ValueError, f.__getitem__, 'x')

    # Saving an object array must be refused as well.
    path = os.path.join(tempdir, 'pickle-disabled.npy')
    assert_raises(ValueError, np.save, path, np.array([None], dtype=object),
                  allow_pickle=False)


def test_version_2_0():
f = BytesIO()
# requires more than 2 byte for header
Expand Down

0 comments on commit a2bd3a7

Please sign in to comment.