diff --git a/python/dgl/data/tensor_serialize.py b/python/dgl/data/tensor_serialize.py index 61dffea4e796..d53d7bd26874 100644 --- a/python/dgl/data/tensor_serialize.py +++ b/python/dgl/data/tensor_serialize.py @@ -1,5 +1,6 @@ """For Tensor Serialization""" from __future__ import absolute_import +from ..ndarray import NDArray from .._ffi.function import _init_api from .. import backend as F @@ -18,20 +19,26 @@ def save_tensors(filename, tensor_dict): File name to store dict of tensors. tensor_dict: dict of dgl NDArray or backend tensor Python dict using string as key and tensor as value + + Returns + ---------- + status : bool + Return whether save operation succeeds """ nd_dict = {} + is_empty_dict = len(tensor_dict) == 0 for key, value in tensor_dict.items(): if not isinstance(key, str): raise Exception("Dict key has to be str") if F.is_tensor(value): nd_dict[key] = F.zerocopy_to_dgl_ndarray(value) - elif isinstance(value, nd.NDArray): + elif isinstance(value, NDArray): nd_dict[key] = value else: raise Exception( "Dict value has to be backend tensor or dgl ndarray") - - return _CAPI_SaveNDArrayDict(filename, nd_dict) + + return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict) def load_tensors(filename, return_dgl_ndarray=False): @@ -44,6 +51,11 @@ def load_tensors(filename, return_dgl_ndarray=False): File name to load dict of tensors. return_dgl_ndarray: bool Whether return dict of dgl NDArrays or backend tensors + + Returns + --------- + tensor_dict : dict + dict of tensor or ndarray based on return_dgl_ndarray flag """ nd_dict = _CAPI_LoadNDArrayDict(filename) tensor_dict = {} diff --git a/src/graph/serialize/tensor_serialize.cc b/src/graph/serialize/tensor_serialize.cc index 9211e768105f..4e2d21e45734 100644 --- a/src/graph/serialize/tensor_serialize.cc +++ b/src/graph/serialize/tensor_serialize.cc @@ -19,36 +19,48 @@ namespace serialize { typedef std::pair NamedTensor; +constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F; + DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict") .set_body([](DGLArgs args, DGLRetValue *rv) { std::string filename = args[0]; - Map nd_dict = args[1]; + auto *fs = dmlc::Stream::Create(filename.c_str(), "w"); + CHECK(fs) << "Filename is invalid"; + fs->Write(kDGLSerialize_Tensors); + bool empty_dict = args[2]; + Map nd_dict; + if (!empty_dict) { + nd_dict = args[1]; + } std::vector namedTensors; + fs->Write(static_cast(nd_dict.size())); for (auto kv : nd_dict) { NDArray ndarray = static_cast(kv.second->data); namedTensors.emplace_back(kv.first, ndarray); } - auto *fs = dynamic_cast( - SeekStream::Create(filename.c_str(), "w", true)); fs->Write(namedTensors); - delete fs; *rv = true; + delete fs; }); DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict") .set_body([](DGLArgs args, DGLRetValue *rv) { std::string filename = args[0]; + auto *fs = dmlc::Stream::Create(filename.c_str(), "r"); + CHECK(fs) << "Filename is invalid or file doesn't exists"; + uint64_t magincNum, num_elements; + CHECK(fs->Read(&magincNum)) << "Invalid file"; + CHECK_EQ(magincNum, kDGLSerialize_Tensors) << "Invalid DGL tensor file"; + CHECK(fs->Read(&num_elements)) << "Invalid num of elements"; Map nd_dict; std::vector namedTensors; - SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true); - CHECK(fs) << "Filename is invalid or file doesn't exists"; fs->Read(&namedTensors); for (auto kv : namedTensors) { Value ndarray = Value(MakeValue(kv.second)); nd_dict.Set(kv.first, ndarray); } - delete fs; *rv = nd_dict; + delete fs; }); } // namespace serialize diff --git a/tests/compute/test_serialize.py b/tests/compute/test_serialize.py index 58b099072a1c..f3521680163f 100644 --- a/tests/compute/test_serialize.py +++ b/tests/compute/test_serialize.py @@ -162,9 +162,25 @@ def test_serialize_tensors(): os.unlink(path) +def test_serialize_empty_dict(): + # create a temporary file and immediately release it so DGL can open it. + f = tempfile.NamedTemporaryFile(delete=False) + path = f.name + f.close() + + tensor_dict = {} + + save_tensors(path, tensor_dict) + + load_tensor_dict = load_tensors(path) + assert isinstance(load_tensor_dict, dict) + assert len(load_tensor_dict) == 0 + + os.unlink(path) if __name__ == "__main__": test_graph_serialize_with_feature() test_graph_serialize_without_feature() test_graph_serialize_with_labels() test_serialize_tensors() + test_serialize_empty_dict() \ No newline at end of file