forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#1520 from reyoung/feature/serialize_d…
…eserialize_in_parameters Add save/load parameters.
- Loading branch information
Showing
6 changed files
with
153 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,6 @@ plot.png | |
train.log | ||
*pyc | ||
.ipynb_checkpoints | ||
params.pkl | ||
params.tar | ||
params.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import unittest | ||
import sys | ||
|
||
try: | ||
import py_paddle | ||
|
||
del py_paddle | ||
except ImportError: | ||
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \ | ||
"unittest will not be run." | ||
sys.exit(0) | ||
|
||
import paddle.v2.parameters as parameters | ||
from paddle.proto.ParameterConfig_pb2 import ParameterConfig | ||
import random | ||
import cStringIO | ||
import numpy | ||
|
||
|
||
def __rand_param_config__(name): | ||
conf = ParameterConfig() | ||
conf.name = name | ||
size = 1 | ||
for i in xrange(2): | ||
dim = random.randint(1, 1000) | ||
conf.dims.append(dim) | ||
size *= dim | ||
conf.size = size | ||
assert conf.IsInitialized() | ||
return conf | ||
|
||
|
||
class TestParameters(unittest.TestCase): | ||
def test_serialization(self): | ||
params = parameters.Parameters() | ||
params.__append_config__(__rand_param_config__("param_0")) | ||
params.__append_config__(__rand_param_config__("param_1")) | ||
|
||
for name in params.names(): | ||
param = params.get(name) | ||
param[:] = numpy.random.uniform( | ||
-1.0, 1.0, size=params.get_shape(name)) | ||
params.set(name, param) | ||
|
||
tmp_file = cStringIO.StringIO() | ||
params.to_tar(tmp_file) | ||
tmp_file.seek(0) | ||
params_dup = parameters.Parameters.from_tar(tmp_file) | ||
|
||
self.assertEqual(params_dup.names(), params.names()) | ||
|
||
for name in params.names(): | ||
self.assertEqual(params.get_shape(name), params_dup.get_shape(name)) | ||
p0 = params.get(name) | ||
p1 = params_dup.get(name) | ||
self.assertTrue(numpy.isclose(p0, p1).all()) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters