Merge pull request PaddlePaddle#1520 from reyoung/feature/serialize_deserialize_in_parameters

Add save/load parameters.
reyoung authored Mar 6, 2017
2 parents 5f2cbce + c36a3f4 commit 963bd5d
Showing 6 changed files with 153 additions and 8 deletions.
3 changes: 3 additions & 0 deletions demo/mnist/.gitignore
@@ -5,3 +5,6 @@ plot.png
train.log
*pyc
.ipynb_checkpoints
params.pkl
params.tar
params.tar.gz
23 changes: 18 additions & 5 deletions demo/mnist/api_train_v2.py
@@ -1,4 +1,5 @@
import paddle.v2 as paddle
import gzip


def softmax_regression(img):
@@ -71,7 +72,11 @@ def main():

cost = paddle.layer.classification_cost(input=predict, label=label)

parameters = paddle.parameters.create(cost)
try:
with gzip.open('params.tar.gz', 'r') as f:
parameters = paddle.parameters.Parameters.from_tar(f)
except IOError:
parameters = paddle.parameters.create(cost)

optimizer = paddle.optimizer.Momentum(
learning_rate=0.1 / 128.0,
@@ -86,10 +91,18 @@ def main():

def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
if event.batch_id % 1000 == 0:
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test(), batch_size=256))

print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)

with gzip.open('params.tar.gz', 'w') as f:
parameters.to_tar(f)

elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test(), batch_size=128))
print "Test with Pass %d, Cost %f, %s\n" % (
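
With this change the MNIST demo checkpoints its parameters to params.tar.gz every 1000 batches and resumes from that archive on the next run, falling back to freshly created parameters when no checkpoint exists yet. The same pattern, pulled out as two small helpers, is sketched below against the v2 API used above; the helper names are illustrative, not part of PaddlePaddle:

import gzip

import paddle.v2 as paddle


def load_or_create_parameters(cost, path='params.tar.gz'):
    # Resume from an existing checkpoint; fall back to freshly created
    # parameters on the first run, when no archive exists yet.
    try:
        with gzip.open(path, 'r') as f:
            return paddle.parameters.Parameters.from_tar(f)
    except IOError:
        return paddle.parameters.create(cost)


def save_parameters(parameters, path='params.tar.gz'):
    # Write the current parameter values as a gzip-compressed tar,
    # the same layout Parameters.from_tar() reads back.
    with gzip.open(path, 'w') as f:
        parameters.to_tar(f)
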
71 changes: 70 additions & 1 deletion python/paddle/v2/parameters.py
@@ -1,7 +1,9 @@
import numpy as np
import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig

import struct
import tarfile
import cStringIO
from topology import Topology

__all__ = ['Parameters', 'create']
@@ -122,6 +124,12 @@ def __getitem__(self, key):

if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy.
if len(self.__tmp_params__) != 0:
ret_list = [
mat for name, mat in self.__tmp_params__ if name == key
]
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32)
else:
for each_gradient_machine in self.__gradient_machines__:
@@ -228,6 +236,67 @@ def append_gradient_machine(self, gradient_machine):

self.__gradient_machines__.append(gradient_machine)

def serialize(self, name, f):
"""
:param name:
:param f:
:type f: file
:return:
"""
param = self.get(name)
size = reduce(lambda a, b: a * b, param.shape)
f.write(struct.pack("IIQ", 0, 4, size))
param = param.astype(np.float32)
f.write(param.tobytes())

def deserialize(self, name, f):
"""
:param name:
:param f:
:type f: file
:return:
"""
f.read(16) # header
arr = np.frombuffer(f.read(), dtype=np.float32)
self.set(name, arr.reshape(self.get_shape(name)))

def to_tar(self, f):
tar = tarfile.TarFile(fileobj=f, mode='w')
for nm in self.names():
buf = cStringIO.StringIO()
self.serialize(nm, buf)
tarinfo = tarfile.TarInfo(name=nm)
buf.seek(0)
tarinfo.size = len(buf.getvalue())
tar.addfile(tarinfo, buf)

conf = self.__param_conf__[nm]
confStr = conf.SerializeToString()
tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm)
tarinfo.size = len(confStr)
buf = cStringIO.StringIO(confStr)
buf.seek(0)
tar.addfile(tarinfo, fileobj=buf)

@staticmethod
def from_tar(f):
params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar:
assert isinstance(finfo, tarfile.TarInfo)
if finfo.name.endswith('.protobuf'):
f = tar.extractfile(finfo)
conf = ParameterConfig()
conf.ParseFromString(f.read())
params.__append_config__(conf)

for param_name in params.names():
f = tar.extractfile(param_name)
params.deserialize(param_name, f)
return params


def __get_parameter_in_gradient_machine__(gradient_machine, name):
"""
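
As the serialize and deserialize methods above show, each parameter is stored in the tar as a flat binary record: a 16-byte header packed as struct.pack("IIQ", 0, 4, size) — by all appearances a version field, the element width in bytes (4 for float32), and the element count — followed by the raw float32 buffer, with a companion "<name>.protobuf" entry carrying the serialized ParameterConfig. A minimal sketch of inspecting such an archive outside Paddle, assuming the params.tar.gz file written by the demo above:

import struct
import tarfile

import numpy as np


def read_param_record(fileobj):
    # 16-byte header: version (uint32), element width in bytes (uint32),
    # element count (uint64), followed by the raw float32 payload.
    version, width, count = struct.unpack("IIQ", fileobj.read(16))
    return np.frombuffer(fileobj.read(count * width), dtype=np.float32)


with tarfile.open('params.tar.gz', mode='r:gz') as tar:
    for member in tar:
        if member.name.endswith('.protobuf'):
            continue  # skip the serialized ParameterConfig entries
        values = read_param_record(tar.extractfile(member))
        print member.name, values.size
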
2 changes: 1 addition & 1 deletion python/paddle/v2/tests/run_tests.sh
@@ -22,7 +22,7 @@ cd $SCRIPTPATH

$1 -m pip install ../../../../paddle/dist/*.whl

test_list="test_data_feeder.py"
test_list="test_data_feeder.py test_parameters.py"

export PYTHONPATH=$PWD/../../../../python/

60 changes: 60 additions & 0 deletions python/paddle/v2/tests/test_parameters.py
@@ -0,0 +1,60 @@
import unittest
import sys

try:
import py_paddle

del py_paddle
except ImportError:
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \
"unittest will not be run."
sys.exit(0)

import paddle.v2.parameters as parameters
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import random
import cStringIO
import numpy


def __rand_param_config__(name):
conf = ParameterConfig()
conf.name = name
size = 1
for i in xrange(2):
dim = random.randint(1, 1000)
conf.dims.append(dim)
size *= dim
conf.size = size
assert conf.IsInitialized()
return conf


class TestParameters(unittest.TestCase):
def test_serialization(self):
params = parameters.Parameters()
params.__append_config__(__rand_param_config__("param_0"))
params.__append_config__(__rand_param_config__("param_1"))

for name in params.names():
param = params.get(name)
param[:] = numpy.random.uniform(
-1.0, 1.0, size=params.get_shape(name))
params.set(name, param)

tmp_file = cStringIO.StringIO()
params.to_tar(tmp_file)
tmp_file.seek(0)
params_dup = parameters.Parameters.from_tar(tmp_file)

self.assertEqual(params_dup.names(), params.names())

for name in params.names():
self.assertEqual(params.get_shape(name), params_dup.get_shape(name))
p0 = params.get(name)
p1 = params_dup.get(name)
self.assertTrue(numpy.isclose(p0, p1).all())


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion python/paddle/v2/trainer.py
@@ -57,9 +57,9 @@ def __init__(self, cost, parameters, update_equation):
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters()
parameters.append_gradient_machine(gm)

def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
"""
