Skip to content

Commit

Permalink
[Fix] Fix stream creation and empty dict in tensor serialization (dml…
Browse files Browse the repository at this point in the history
…c#1489)

* add functions

* fix lint

* add unit test

* fix

* fix

* fix

* fix

* support empty dict

* simplify logic

* Update tensor_serialize.py
  • Loading branch information
VoVAllen authored May 2, 2020
1 parent f25bc17 commit 30b8074
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
18 changes: 15 additions & 3 deletions python/dgl/data/tensor_serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""For Tensor Serialization"""
from __future__ import absolute_import
from ..ndarray import NDArray
from .._ffi.function import _init_api
from .. import backend as F

Expand All @@ -18,20 +19,26 @@ def save_tensors(filename, tensor_dict):
File name to store dict of tensors.
tensor_dict: dict of dgl NDArray or backend tensor
Python dict using string as key and tensor as value
Returns
----------
status : bool
Return whether save operation succeeds
"""
nd_dict = {}
is_empty_dict = len(tensor_dict) == 0
for key, value in tensor_dict.items():
if not isinstance(key, str):
raise Exception("Dict key has to be str")
if F.is_tensor(value):
nd_dict[key] = F.zerocopy_to_dgl_ndarray(value)
elif isinstance(value, nd.NDArray):
elif isinstance(value, NDArray):
nd_dict[key] = value
else:
raise Exception(
"Dict value has to be backend tensor or dgl ndarray")

return _CAPI_SaveNDArrayDict(filename, nd_dict)
return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict)


def load_tensors(filename, return_dgl_ndarray=False):
Expand All @@ -44,6 +51,11 @@ def load_tensors(filename, return_dgl_ndarray=False):
File name to load dict of tensors.
return_dgl_ndarray: bool
Whether return dict of dgl NDArrays or backend tensors
Returns
---------
tensor_dict : dict
dict of tensor or ndarray based on return_dgl_ndarray flag
"""
nd_dict = _CAPI_LoadNDArrayDict(filename)
tensor_dict = {}
Expand Down
26 changes: 19 additions & 7 deletions src/graph/serialize/tensor_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,48 @@ namespace serialize {

typedef std::pair<std::string, NDArray> NamedTensor;

constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;

DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
Map<std::string, Value> nd_dict = args[1];
auto *fs = dmlc::Stream::Create(filename.c_str(), "w");
CHECK(fs) << "Filename is invalid";
fs->Write(kDGLSerialize_Tensors);
bool empty_dict = args[2];
Map<std::string, Value> nd_dict;
if (!empty_dict) {
nd_dict = args[1];
}
std::vector<NamedTensor> namedTensors;
fs->Write(static_cast<uint64_t>(nd_dict.size()));
for (auto kv : nd_dict) {
NDArray ndarray = static_cast<NDArray>(kv.second->data);
namedTensors.emplace_back(kv.first, ndarray);
}
auto *fs = dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true));
fs->Write(namedTensors);
delete fs;
*rv = true;
delete fs;
});

DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
auto *fs = dmlc::Stream::Create(filename.c_str(), "r");
CHECK(fs) << "Filename is invalid or file doesn't exists";
uint64_t magincNum, num_elements;
CHECK(fs->Read(&magincNum)) << "Invalid file";
CHECK_EQ(magincNum, kDGLSerialize_Tensors) << "Invalid DGL tensor file";
CHECK(fs->Read(&num_elements)) << "Invalid num of elements";
Map<std::string, Value> nd_dict;
std::vector<NamedTensor> namedTensors;
SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true);
CHECK(fs) << "Filename is invalid or file doesn't exists";
fs->Read(&namedTensors);
for (auto kv : namedTensors) {
Value ndarray = Value(MakeValue(kv.second));
nd_dict.Set(kv.first, ndarray);
}
delete fs;
*rv = nd_dict;
delete fs;
});

} // namespace serialize
Expand Down
16 changes: 16 additions & 0 deletions tests/compute/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,25 @@ def test_serialize_tensors():

os.unlink(path)

def test_serialize_empty_dict():
    """Round-trip an empty dict through save_tensors / load_tensors.

    Saving an empty dict and loading it back must yield an empty ``dict``.
    The temporary file is removed even if saving, loading, or an assertion
    fails, so a failing run does not leave stray files behind.
    """
    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    try:
        tensor_dict = {}

        save_tensors(path, tensor_dict)

        load_tensor_dict = load_tensors(path)
        assert isinstance(load_tensor_dict, dict)
        assert len(load_tensor_dict) == 0
    finally:
        # delete=False means nothing else cleans this file up for us.
        os.unlink(path)

if __name__ == "__main__":
test_graph_serialize_with_feature()
test_graph_serialize_without_feature()
test_graph_serialize_with_labels()
test_serialize_tensors()
test_serialize_empty_dict()

0 comments on commit 30b8074

Please sign in to comment.