forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add serialization * add serialization * add serialization * lalalalalalalala * lalalalalalalala * serialize * serialize * nnn * WIP: import tvm runtime node system * WIP: object system * containers * tested basic container composition * tested custom object * tmp * fix setattr bug * tested object container return * fix lint * some comments about get/set state * fix lint * fix lint * update cython * fix cython * ffi doc * fix doc * WIP: using object system for graph * c++ side refactoring done; compiled * remove stale apis * fix bug in DGLGraphCreate; passed test_graph.py * fix bug in python modify; passed utest for pytorch/cpu * fix lint * Add serialization * Add serialization * fix * fix typo * serialize with new ffi * commit * commit * commit * save * save * save * save * commit * clean * Delete tt2.py * fix lint * Add serialization * fix lint 2 * fix lint * fix lint * fix lint * fix lint * Fix Lint * Add serialization * Change to Macro * fix * fix * fix bugs * refactor * refactor * updating dmlc-core to include force flag * trying tempfile * delete leaked pointer * Fix assert * fix assert * add comment and test case * add graph labels * add load labels * lint * lint * add graph labels * lint * fix windows * fix * update dmlc-core to latest * fix * fix camel naming
- Loading branch information
Showing
12 changed files
with
795 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
"""For Graph Serialization""" | ||
from __future__ import absolute_import | ||
from ..graph import DGLGraph | ||
from ..batched_graph import BatchedDGLGraph | ||
from .._ffi.object import ObjectBase, register_object | ||
from .._ffi.function import _init_api | ||
from .. import backend as F | ||
|
||
_init_api("dgl.data.graph_serialize") | ||
|
||
__all__ = ['save_graphs', "load_graphs", "load_labels"] | ||
|
||
@register_object("graph_serialize.StorageMetaData") | ||
class StorageMetaData(ObjectBase): | ||
"""StorageMetaData Object | ||
attributes available: | ||
num_graph [int]: return numbers of graphs | ||
nodes_num_list Value of NDArray: return number of nodes for each graph | ||
edges_num_list Value of NDArray: return number of edges for each graph | ||
labels [dict of backend tensors]: return dict of labels | ||
graph_data [list of GraphData]: return list of GraphData Object | ||
""" | ||
|
||
|
||
@register_object("graph_serialize.GraphData") | ||
class GraphData(ObjectBase): | ||
"""GraphData Object""" | ||
|
||
@staticmethod | ||
def create(g: DGLGraph): | ||
"""Create GraphData""" | ||
assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization" | ||
ghandle = g._graph | ||
if len(g.ndata) != 0: | ||
node_tensors = dict() | ||
for key, value in g.ndata.items(): | ||
node_tensors[key] = F.zerocopy_to_dgl_ndarray(value) | ||
else: | ||
node_tensors = None | ||
|
||
if len(g.edata) != 0: | ||
edge_tensors = dict() | ||
for key, value in g.edata.items(): | ||
edge_tensors[key] = F.zerocopy_to_dgl_ndarray(value) | ||
else: | ||
edge_tensors = None | ||
|
||
return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors) | ||
|
||
def get_graph(self): | ||
"""Get DGLGraph from GraphData""" | ||
ghandle = _CAPI_GDataGraphHandle(self) | ||
g = DGLGraph(graph_data=ghandle, readonly=True) | ||
node_tensors_items = _CAPI_GDataNodeTensors(self).items() | ||
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items() | ||
for k, v in node_tensors_items: | ||
g.ndata[k] = F.zerocopy_from_dgl_ndarray(v.data) | ||
for k, v in edge_tensors_items: | ||
g.edata[k] = F.zerocopy_from_dgl_ndarray(v.data) | ||
return g | ||
|
||
|
||
def save_graphs(filename, g_list, labels=None): | ||
r""" | ||
Save DGLGraphs and graph labels to file | ||
Parameters | ||
---------- | ||
filename : str | ||
File name to store DGLGraphs. | ||
g_list: list | ||
DGLGraph or list of DGLGraph | ||
labels: dict (Default: None) | ||
labels should be dict of tensors/ndarray, with str as keys | ||
Examples | ||
---------- | ||
>>> import dgl | ||
>>> import torch as th | ||
Create :code:`DGLGraph` objects and initialize node and edge features. | ||
>>> g1 = dgl.DGLGraph() | ||
>>> g1.add_nodes(3) | ||
>>> g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]) | ||
>>> g1.ndata["e"] = th.ones(3, 5) | ||
>>> g2 = dgl.DGLGraph() | ||
>>> g2.add_nodes(3) | ||
>>> g2.add_edges([0, 1, 2], [1, 2, 1]) | ||
>>> g1.edata["e"] = th.ones(3, 4) | ||
Save Graphs into file | ||
>>> from dgl.data.utils import save_graphs | ||
>>> graph_labels = {"glabel": th.tensor([0, 1])} | ||
>>> save_graphs([g1, g2], "./data.bin", graph_labels) | ||
""" | ||
if isinstance(g_list, DGLGraph): | ||
g_list = [g_list] | ||
if (labels is not None) and (len(labels) != 0): | ||
label_dict = dict() | ||
for key, value in labels.items(): | ||
label_dict[key] = F.zerocopy_to_dgl_ndarray(value) | ||
else: | ||
label_dict = None | ||
gdata_list = [GraphData.create(g) for g in g_list] | ||
_CAPI_DGLSaveGraphs(filename, gdata_list, label_dict) | ||
|
||
|
||
def load_graphs(filename, idx_list=None): | ||
""" | ||
Load DGLGraphs from file | ||
Parameters | ||
---------- | ||
filename: str | ||
filename to load DGLGraphs | ||
idx_list: list of int | ||
list of index of graph to be loaded. If not specified, will | ||
load all graphs from file | ||
Returns | ||
---------- | ||
graph_list: list of immutable DGLGraphs | ||
labels: dict of labels stored in file (empty dict returned if no | ||
label stored) | ||
Examples | ||
---------- | ||
Following the example in save_graphs. | ||
>>> from dgl.utils.data import load_graphs | ||
>>> glist, label_dict = load_graphs("./data.bin") # glist will be [g1, g2] | ||
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1] | ||
""" | ||
assert isinstance(idx_list, list) | ||
if idx_list is None: | ||
idx_list = [] | ||
metadata = _CAPI_DGLLoadGraphs(filename, idx_list, False) | ||
label_dict = {} | ||
for k, v in metadata.labels.items(): | ||
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data) | ||
|
||
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict | ||
|
||
|
||
def load_labels(filename): | ||
""" | ||
Load label dict from file | ||
Parameters | ||
---------- | ||
filename: str | ||
filename to load DGLGraphs | ||
Returns | ||
---------- | ||
labels: dict | ||
dict of labels stored in file (empty dict returned if no | ||
label stored) | ||
Examples | ||
---------- | ||
Following the example in save_graphs. | ||
>>> from dgl.data.utils import load_labels | ||
>>> label_dict = load_graphs("./data.bin") | ||
""" | ||
metadata = _CAPI_DGLLoadGraphs(filename, [], True) | ||
label_dict = {} | ||
for k, v in metadata.labels.items(): | ||
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data) | ||
return label_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.