Skip to content

Commit

Permalink
Save/Load models to binary stream (keras-team#11708)
Browse files Browse the repository at this point in the history
* add helper functions for loading/saving to binary

* move h5py binary helpers to io_utils, integrate in load/save_model

* add docs of new functionality

* add tests for saving/loading models in binary form

* fix PR comments, improve docs

* remove remaining unnecessary blank lines

* fix minor codestyle

* improve docs

* fix review comments

* remove H5Dict.opens_file

* align docs, weights -> model

* Style fixes in docstrings

* Fix PEP8.
  • Loading branch information
andhus authored and fchollet committed Mar 30, 2019
1 parent ad578c4 commit c3eb627
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 75 deletions.
6 changes: 5 additions & 1 deletion keras/engine/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,11 @@ def save(self, filepath, overwrite=True, include_optimizer=True):
was never compiled in the first place).
# Arguments
filepath: String, path to the file to save the weights to.
filepath: one of the following:
- string, path to the file to save the model to
- h5py.File or h5py.Group object where to save the model
- any file-like object implementing the method `write` that accepts
`bytes` data (e.g. `io.BytesIO`).
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
Expand Down
54 changes: 31 additions & 23 deletions keras/engine/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

from .. import backend as K
from .. import optimizers
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils.io_utils import H5Dict
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils.io_utils import save_to_binary_h5py
from ..utils.io_utils import load_from_binary_h5py
from ..utils import conv_utils

try:
Expand Down Expand Up @@ -475,8 +477,10 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
# Arguments
model: Keras model instance to be saved.
filepath: one of the following:
- string, path where to save the model, or
- string, path to the file to save the model to
- h5py.File or h5py.Group object where to save the model
- any file-like object implementing the method `write` that accepts
`bytes` data (e.g. `io.BytesIO`).
overwrite: Whether we should overwrite any existing
model at the target location, or instead
ask the user with a manual prompt.
Expand All @@ -488,22 +492,21 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
if h5py is None:
raise ImportError('`save_model` requires h5py.')

if not isinstance(filepath, h5py.Group):
# If file exists and should not be overwritten.
if not overwrite and os.path.isfile(filepath):
if H5Dict.is_supported_type(filepath):
opens_file = not isinstance(filepath, (dict, h5py.Group))
if opens_file and os.path.isfile(filepath) and not overwrite:
proceed = ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
opened_new_file = True
with H5Dict(filepath, mode='w') as h5dict:
_serialize_model(model, h5dict, include_optimizer)
elif hasattr(filepath, 'write') and callable(filepath.write):
# write as binary stream
def save_function(h5file):
_serialize_model(model, H5Dict(h5file), include_optimizer)
save_to_binary_h5py(save_function, filepath)
else:
opened_new_file = False

h5dict = H5Dict(filepath, mode='w')
try:
_serialize_model(model, h5dict, include_optimizer)
finally:
if opened_new_file:
h5dict.close()
raise ValueError('unexpected type {} for `filepath`'.format(type(filepath)))


@allow_read_from_gcs
Expand All @@ -512,8 +515,10 @@ def load_model(filepath, custom_objects=None, compile=True):
# Arguments
filepath: one of the following:
- string, path to the saved model, or
- string, path to the saved model
- h5py.File or h5py.Group object from which to load the model
- any file-like object implementing the method `read` that returns
`bytes` data (e.g. `io.BytesIO`) that represents a valid h5py file image.
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
considered during deserialization.
Expand All @@ -534,14 +539,17 @@ def load_model(filepath, custom_objects=None, compile=True):
"""
if h5py is None:
raise ImportError('`load_model` requires h5py.')
model = None
opened_new_file = not isinstance(filepath, h5py.Group)
h5dict = H5Dict(filepath, 'r')
try:
model = _deserialize_model(h5dict, custom_objects, compile)
finally:
if opened_new_file:
h5dict.close()

if H5Dict.is_supported_type(filepath):
with H5Dict(filepath, mode='r') as h5dict:
model = _deserialize_model(h5dict, custom_objects, compile)
elif hasattr(filepath, 'write') and callable(filepath.write):
def load_function(h5file):
return _deserialize_model(H5Dict(h5file), custom_objects, compile)
model = load_from_binary_h5py(load_function, filepath)
else:
raise ValueError('unexpected type {} for `filepath`'.format(type(filepath)))

return model


Expand Down
82 changes: 75 additions & 7 deletions keras/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from collections import defaultdict
import sys
import contextlib


import six
Expand Down Expand Up @@ -183,16 +184,10 @@ class H5Dict(object):
"""

def __init__(self, path, mode='a'):
def is_path_instance(path):
# We can't use isinstance here because it would require
# us to add pathlib2 to the Python 2 dependencies.
class_name = type(path).__name__
return class_name == 'PosixPath' or class_name == 'WindowsPath'

if isinstance(path, h5py.Group):
self.data = path
self._is_file = False
elif isinstance(path, six.string_types) or is_path_instance(path):
elif isinstance(path, six.string_types) or _is_path_instance(path):
self.data = h5py.File(path, mode=mode)
self._is_file = True
elif isinstance(path, dict):
Expand All @@ -207,6 +202,16 @@ def is_path_instance(path):
'Received: {}.'.format(type(path)))
self.read_only = mode == 'r'

@staticmethod
def is_supported_type(path):
    """Tell whether `path` can be used to instantiate a `H5Dict`.

    # Arguments
        path: Object to check.

    # Returns
        `True` if `path` is a `h5py.Group`, a `dict`, a string or a
        pathlib path object, `False` otherwise.
    """
    # Already-open containers are accepted as they are.
    if isinstance(path, (h5py.Group, dict)):
        return True
    # Otherwise it has to be something that names a file on disk.
    return isinstance(path, six.string_types) or _is_path_instance(path)

def __setitem__(self, attr, val):
if self.read_only:
raise ValueError('Cannot set item in read-only mode.')
Expand Down Expand Up @@ -358,5 +363,68 @@ def get(self, key, default=None):
return self[key]
return default

def __enter__(self):
    """Enter the `with` block; the object itself is the context value."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the `with` block by calling `close()`.

    The exception arguments are ignored and nothing is returned, so an
    exception raised inside the block keeps propagating.
    """
    self.close()


h5dict = H5Dict


def load_from_binary_h5py(load_function, stream):
    """Runs `load_function` on a `h5py.File` built from the binary `stream`.

    # Arguments
        load_function: A function that takes a `h5py.File`, reads from it,
            and returns any object.
        stream: Any file-like object implementing the method `read` that
            returns `bytes` data (e.g. `io.BytesIO`) representing a valid
            h5py file image.

    # Returns
        The object returned by `load_function`.
    """
    # Implementation based on the solution suggested here:
    # https://github.com/keras-team/keras/issues/9343#issuecomment-440903847
    image = stream.read()

    # File-access property list: core driver without a backing store,
    # pre-loaded with the bytes read from the stream.
    fapl = h5py.h5p.create(h5py.h5p.FILE_ACCESS)
    fapl.set_fapl_core(backing_store=False)
    fapl.set_file_image(image)

    # The name is required by the low-level API but does not matter.
    low_level_file = h5py.h5f.open(name=b'in-memory-h5py',
                                   flags=h5py.h5f.ACC_RDONLY,
                                   fapl=fapl)
    with contextlib.closing(low_level_file) as file_id:
        with h5py.File(file_id,
                       backing_store=False,
                       driver='core',
                       mode='r') as h5_file:
            return load_function(h5_file)


def save_to_binary_h5py(save_function, stream):
    """Calls `save_function` on an in memory `h5py.File`.

    The file is subsequently written to the binary `stream`.

    # Arguments
        save_function: A function that takes a `h5py.File`, writes to it and
            (optionally) returns any object.
        stream: Any file-like object implementing the method `write` that
            accepts `bytes` data (e.g. `io.BytesIO`).

    # Returns
        The object returned by `save_function`.
    """
    # The filename does not matter here. The mode is passed explicitly
    # because the default mode of `h5py.File` differs between h5py versions
    # (read-only from h5py 3.0 on), and a fresh, writable file is needed.
    with h5py.File('in-memory-h5py', mode='w', driver='core',
                   backing_store=False) as h5file:
        return_value = save_function(h5file)
        # Flush so that the file image contains everything that was written.
        h5file.flush()
        # `id` is the low-level file identifier (`fid` is a deprecated alias).
        binary_data = h5file.id.get_file_image()
        stream.write(binary_data)

    return return_value


def _is_path_instance(path):
    """Tell whether `path` is a concrete pathlib path, judged by class name."""
    # The class name is compared instead of using `isinstance`, because
    # `isinstance` would require adding pathlib2 to the Python 2 dependencies.
    return type(path).__name__ in ('PosixPath', 'WindowsPath')
117 changes: 117 additions & 0 deletions tests/keras/utils/io_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
'''Tests for functions in io_utils.py.
'''
import os
import io
import pytest

from contextlib import contextmanager

from keras.models import Sequential
from keras.layers import Dense
from keras.utils.io_utils import HDF5Matrix
from keras.utils.io_utils import H5Dict
from keras.utils.io_utils import ask_to_proceed_with_overwrite
from keras.utils.io_utils import save_to_binary_h5py
from keras.utils.io_utils import load_from_binary_h5py
from numpy.testing import assert_allclose
from numpy.testing import assert_array_equal
import numpy as np
import six
import h5py
Expand Down Expand Up @@ -232,5 +239,115 @@ def test_H5Dict_accepts_pathlib_Path():
os.remove(h5_path)


@contextmanager
def temp_filename(suffix):
    """Context that returns a temporary filename and deletes the file on exit
    if it still exists (so that this is not forgotten).

    # Arguments
        suffix: String, suffix of the temporary filename (e.g. `'.h5'`).

    # Yields
        String, path of the (empty) temporary file.
    """
    fd, temp_fname = tempfile.mkstemp(suffix=suffix)
    # `mkstemp` hands back an open OS-level descriptor. Only the name is
    # needed, so close it right away instead of leaking it (an open handle
    # would also prevent removing the file on Windows).
    os.close(fd)
    try:
        yield temp_fname
    finally:
        # Clean up even when the body of the `with` block raised.
        if os.path.exists(temp_fname):
            os.remove(temp_fname)


def test_save_to_binary_h5py_direct_to_file():
    """`save_to_binary_h5py` writing to a real binary file handle must
    produce a file that h5py can open and read back."""
    data = np.random.random((3, 5))

    def save_function(h5file_):
        h5file_['data'] = data

    with temp_filename('.h5') as fname:
        with open(fname, 'wb') as f:
            save_to_binary_h5py(save_function, f)

        # Open read-only explicitly; the implicit default mode of
        # `h5py.File` is not the same across h5py versions.
        with h5py.File(fname, 'r') as h5file:
            data_rec = h5file['data'][:]

    assert_array_equal(data_rec, data)


def test_save_to_binary_h5py_to_bytes_io():
    """The bytes that `save_to_binary_h5py` writes to an `io.BytesIO` must
    form a valid HDF5 file once dumped to disk."""
    data = np.random.random((3, 5))

    def save_function(h5file_):
        h5file_['data'] = data

    file_like = io.BytesIO()
    save_to_binary_h5py(save_function, file_like)

    # Rewind, so that the subsequent `read` returns the whole image.
    file_like.seek(0)

    with temp_filename('.h5') as fname:
        with open(fname, 'wb') as f:
            f.write(file_like.read())

        # Open read-only explicitly; the implicit default mode of
        # `h5py.File` is not the same across h5py versions.
        with h5py.File(fname, 'r') as h5file:
            data_rec = h5file['data'][:]

    assert_array_equal(data_rec, data)


def test_load_from_binary_h5py_direct_from_file():
    """`load_from_binary_h5py` must recover a dataset when given a real
    file handle opened in binary mode."""
    expected = np.random.random((3, 5))

    def read_dataset(h5file_):
        return h5file_['data'][:]

    with temp_filename('.h5') as path:
        # Create the reference file with plain h5py ...
        with h5py.File(path, 'w') as writer:
            writer['data'] = expected

        # ... and read it back through the binary loading helper.
        with open(path, 'rb') as raw:
            recovered = load_from_binary_h5py(read_dataset, raw)

    assert_array_equal(recovered, expected)


def test_load_from_binary_h5py_from_bytes_io():
    """`load_from_binary_h5py` must recover a dataset from an in-memory
    `io.BytesIO` holding the bytes of a HDF5 file."""
    expected = np.random.random((3, 5))

    def read_dataset(h5file_):
        return h5file_['data'][:]

    with temp_filename('.h5') as path:
        # Create the reference file with plain h5py.
        with h5py.File(path, 'w') as writer:
            writer['data'] = expected

        # Copy its raw bytes into a buffer positioned at the start.
        with open(path, 'rb') as raw:
            buffer = io.BytesIO(raw.read())

    recovered = load_from_binary_h5py(read_dataset, buffer)

    assert_array_equal(recovered, expected)


def test_save_load_binary_h5py():
    """Round trip: what `save_to_binary_h5py` writes to a buffer must be
    read back unchanged by `load_from_binary_h5py` (top-level dataset,
    dataset in a subgroup and an attribute)."""
    first = np.random.random((3, 5))
    second = np.random.random((2, 3, 5))
    attr_value = 1
    expected = (first, second, attr_value)

    def write_all(h5file_):
        h5file_['data1'] = first
        h5file_['subgroup/data2'] = second
        h5file_['data1'].attrs['attr'] = attr_value

    def read_all(h5file_):
        return (h5file_['data1'][:],
                h5file_['subgroup/data2'][:],
                h5file_['data1'].attrs['attr'])

    buffer = io.BytesIO()
    save_to_binary_h5py(write_all, buffer)
    # Rewind before handing the buffer to the loader.
    buffer.seek(0)
    recovered = load_from_binary_h5py(read_all, buffer)

    for got, want in zip(recovered, expected):
        assert_array_equal(got, want)


if __name__ == '__main__':
pytest.main([__file__])
Loading

0 comments on commit c3eb627

Please sign in to comment.