Skip to content

Commit

Permalink
[Feature] Data format (dmlc#728)
Browse files Browse the repository at this point in the history
* Add serialization

* add serialization

* add serialization

* lalalalalalalala

* lalalalalalalala

* serialize

* serialize

* nnn

* WIP: import tvm runtime node system

* WIP: object system

* containers

* tested basic container composition

* tested custom object

* tmp

* fix setattr bug

* tested object container return

* fix lint

* some comments about get/set state

* fix lint

* fix lint

* update cython

* fix cython

* ffi doc

* fix doc

* WIP: using object system for graph

* c++ side refactoring done; compiled

* remove stale apis

* fix bug in DGLGraphCreate; passed test_graph.py

* fix bug in python modify; passed utest for pytorch/cpu

* fix lint

* Add serialization

* Add serialization

* fix

* fix typo

* serialize with new ffi

* commit

* commit

* commit

* save

* save

* save

* save

* commit

* clean

* Delete tt2.py

* fix lint

* Add serialization

* fix lint 2

* fix lint

* fix lint

* fix lint

* fix lint

* Fix Lint

* Add serialization

* Change to Macro

* fix

* fix

* fix bugs

* refactor

* refactor

* updating dmlc-core to include force flag

* trying tempfile

* delete leaked pointer

* Fix assert

* fix assert

* add comment and test case

* add graph labels

* add load labels

* lint

* lint

* add graph labels

* lint

*  fix windows

* fix

* update dmlc-core to latest

* fix

* fix camel naming
  • Loading branch information
VoVAllen authored Sep 9, 2019
1 parent 6a4b5ae commit 0fb13f7
Show file tree
Hide file tree
Showing 12 changed files with 795 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ else(USE_CUDA)
add_library(dgl SHARED ${DGL_SRC})
endif(USE_CUDA)

# For serialization
add_subdirectory("third_party/dmlc-core")
list(APPEND DGL_LINKER_LIBS dmlc)
# NOTE(review): GOOGLE_TEST is assigned only after add_subdirectory() has
# already processed third_party/dmlc-core, so it may not influence that
# subdirectory's configuration -- TODO confirm the dmlc-core tests are
# really disabled, or move this line above add_subdirectory().
set(GOOGLE_TEST 0) # Turn off dmlc-core test

target_link_libraries(dgl ${DGL_LINKER_LIBS} ${DGL_RUNTIME_LINKER_LIBS})

# Installation rules
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api/python/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Utils
utils.check_sha1
utils.extract_archive
utils.split_dataset
utils.save_graphs
utils.load_graphs
utils.load_labels

.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
Expand Down
2 changes: 1 addition & 1 deletion include/dgl/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2019 by Contributors
* \file packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass ObjectRef types into/from PackedFunc.
* This enables pass ObjectRef types into/from PackedFunc.
*/
#ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_
Expand Down
2 changes: 1 addition & 1 deletion include/dgl/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class IterAdapter {
* values, try use the constructor to create the list at once (for example
* from an existing vector).
*
* operator[] only provide const acces, use Set to mutate the content.
* operator[] only provide const access, use Set to mutate the content.
*
* \tparam T The content ObjectRef type.
*/
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StrMap(Map):
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
return [(akvs[i].data, akvs[i+1]) for i in range(0, len(akvs), 2)]

@register_object
class Value(ObjectBase):
Expand Down
176 changes: 176 additions & 0 deletions python/dgl/data/graph_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""For Graph Serialization"""
from __future__ import absolute_import
from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F

_init_api("dgl.data.graph_serialize")

__all__ = ['save_graphs', "load_graphs", "load_labels"]

@register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase):
    """Metadata object produced when graphs are loaded from file.

    Attributes available:

    num_graph : int
        Number of graphs.
    nodes_num_list : Value of NDArray
        Number of nodes for each graph.
    edges_num_list : Value of NDArray
        Number of edges for each graph.
    labels : dict
        Labels keyed by name; this module reads each entry's ``.data``
        to convert it to a backend tensor.
    graph_data : list of GraphData
        One GraphData object per graph.
    """


@register_object("graph_serialize.GraphData")
class GraphData(ObjectBase):
    """Container pairing a graph handle with its node and edge tensors."""

    @staticmethod
    def create(g: DGLGraph):
        """Build a GraphData object from a DGLGraph.

        Node and edge features are converted to DGL NDArrays without
        copying. When the graph carries no node (or edge) features,
        ``None`` is passed for that collection instead of an empty dict.
        BatchedDGLGraph inputs are rejected.
        """
        assert not isinstance(g, BatchedDGLGraph), \
            "BatchedDGLGraph is not supported for serialization"
        ghandle = g._graph

        if len(g.ndata) == 0:
            node_tensors = None
        else:
            node_tensors = {
                name: F.zerocopy_to_dgl_ndarray(feat)
                for name, feat in g.ndata.items()
            }

        if len(g.edata) == 0:
            edge_tensors = None
        else:
            edge_tensors = {
                name: F.zerocopy_to_dgl_ndarray(feat)
                for name, feat in g.edata.items()
            }

        return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)

    def get_graph(self):
        """Rebuild a read-only DGLGraph, restoring node and edge features."""
        graph = DGLGraph(graph_data=_CAPI_GDataGraphHandle(self), readonly=True)
        node_items = _CAPI_GDataNodeTensors(self).items()
        edge_items = _CAPI_GDataEdgeTensors(self).items()
        # Each stored value wraps an NDArray exposed through ``.data``.
        for name, wrapped in node_items:
            graph.ndata[name] = F.zerocopy_from_dgl_ndarray(wrapped.data)
        for name, wrapped in edge_items:
            graph.edata[name] = F.zerocopy_from_dgl_ndarray(wrapped.data)
        return graph


def save_graphs(filename, g_list, labels=None):
    r"""
    Save DGLGraphs and graph labels to file

    Parameters
    ----------
    filename : str
        File name to store DGLGraphs.
    g_list: list
        DGLGraph or list of DGLGraph
    labels: dict (Default: None)
        labels should be dict of tensors/ndarray, with str as keys

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create :code:`DGLGraph` objects and initialize node and edge features.

    >>> g1 = dgl.DGLGraph()
    >>> g1.add_nodes(3)
    >>> g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    >>> g1.ndata["e"] = th.ones(3, 5)
    >>> g2 = dgl.DGLGraph()
    >>> g2.add_nodes(3)
    >>> g2.add_edges([0, 1, 2], [1, 2, 1])
    >>> g2.edata["e"] = th.ones(3, 4)

    Save Graphs into file

    >>> from dgl.data.utils import save_graphs
    >>> graph_labels = {"glabel": th.tensor([0, 1])}
    >>> save_graphs("./data.bin", [g1, g2], graph_labels)
    """
    # Accept a single graph as a convenience and normalize it to a list.
    if isinstance(g_list, DGLGraph):
        g_list = [g_list]
    # A missing or empty label dict is passed to the C API as None.
    if (labels is not None) and (len(labels) != 0):
        label_dict = {
            key: F.zerocopy_to_dgl_ndarray(value)
            for key, value in labels.items()
        }
    else:
        label_dict = None
    gdata_list = [GraphData.create(g) for g in g_list]
    _CAPI_DGLSaveGraphs(filename, gdata_list, label_dict)


def load_graphs(filename, idx_list=None):
    """
    Load DGLGraphs from file

    Parameters
    ----------
    filename: str
        filename to load DGLGraphs
    idx_list: list of int
        list of index of graph to be loaded. If not specified, will
        load all graphs from file

    Returns
    -------
    graph_list: list of immutable DGLGraphs
    labels: dict of labels stored in file (empty dict returned if no
        label stored)

    Examples
    --------
    Following the example in save_graphs.

    >>> from dgl.data.utils import load_graphs
    >>> glist, label_dict = load_graphs("./data.bin") # glist will be [g1, g2]
    >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
    """
    # Resolve the default before validating; otherwise omitting idx_list
    # (None) would always trip the type assertion below.
    if idx_list is None:
        idx_list = []
    assert isinstance(idx_list, list), "idx_list must be a list of int or None"
    metadata = _CAPI_DGLLoadGraphs(filename, idx_list, False)
    label_dict = {
        key: F.zerocopy_from_dgl_ndarray(value.data)
        for key, value in metadata.labels.items()
    }

    return [gdata.get_graph() for gdata in metadata.graph_data], label_dict


def load_labels(filename):
    """
    Load label dict from file

    Parameters
    ----------
    filename: str
        filename to load labels from

    Returns
    -------
    labels: dict
        dict of labels stored in file (empty dict returned if no
        label stored)

    Examples
    --------
    Following the example in save_graphs.

    >>> from dgl.data.utils import load_labels
    >>> label_dict = load_labels("./data.bin")
    """
    # Empty index list and a True third argument, unlike load_graphs
    # (which passes False); only metadata.labels is read from the result.
    metadata = _CAPI_DGLLoadGraphs(filename, [], True)
    return {
        key: F.zerocopy_from_dgl_ndarray(value.data)
        for key, value in metadata.labels.items()
    }
6 changes: 5 additions & 1 deletion python/dgl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import zipfile
import tarfile
import numpy as np

from .graph_serialize import save_graphs, load_graphs, load_labels

try:
import requests
except ImportError:
Expand All @@ -16,7 +19,8 @@ class requests_failed_to_import(object):
requests = requests_failed_to_import

__all__ = ['download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset']
'get_download_dir', 'Subset', 'split_dataset',
'save_graphs', "load_graphs", "load_labels"]


def _get_dgl_url(file_url):
Expand Down
Loading

0 comments on commit 0fb13f7

Please sign in to comment.