[performance] Batch DGLGraph on the C++ end. (dmlc#2155)
* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* fix

* upd

Co-authored-by: VoVAllen <[email protected]>
yzh119 and VoVAllen authored Sep 10, 2020
1 parent ac570c1 commit cbd55eb
Showing 10 changed files with 41 additions and 128 deletions.
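
Before the per-file diffs, a quick orientation: the user-facing entry point affected by this commit is dgl.batch, whose graph-structure work now happens in a single C++ disjoint-union call instead of Python loops. The sketch below is a hypothetical usage example (graph sizes and the feature name 'h' are made up), showing the call this change speeds up rather than code from the commit itself.

import torch
import dgl

# Two small homogeneous graphs with a node feature 'h' (illustrative only).
g1 = dgl.graph((torch.tensor([0]), torch.tensor([1])))
g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g1.ndata['h'] = torch.randn(2, 4)
g2.ndata['h'] = torch.randn(3, 4)

bg = dgl.batch([g1, g2])            # structure union now done on the C++ end
print(bg.batch_num_nodes())         # tensor([2, 3]) -- per-component node counts
print(bg.num_edges())               # 3
print(bg.ndata['h'].shape)          # torch.Size([5, 4])
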
3 changes: 2 additions & 1 deletion python/dgl/backend/tensorflow/tensor.py
@@ -145,7 +145,8 @@ def to_backend_ctx(dglctx):


def astype(input, ty):
return tf.cast(input, dtype=ty)
with tf.device(input.device):
return tf.cast(input, dtype=ty)


def asnumpy(input):
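
The only TensorFlow-backend change above pins the cast to the input tensor's device: without the explicit device scope, tf.cast can place its result according to TensorFlow's default device policy rather than where the input lives. A standalone sketch of the same pattern (outside DGL, with an illustrative tensor):

import tensorflow as tf

def astype(input, ty):
    # Keep the cast on the same device as its input instead of the default device.
    with tf.device(input.device):
        return tf.cast(input, dtype=ty)

x = tf.constant([1, 2, 3], dtype=tf.int64)
y = astype(x, tf.int32)
assert y.device == x.device
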
45 changes: 15 additions & 30 deletions python/dgl/batch.py
@@ -1,12 +1,14 @@
"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from collections import defaultdict

from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning
from .heterograph_index import disjoint_union
from .heterograph import DGLHeteroGraph
from . import convert
from . import utils


__all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero']

def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
@@ -156,61 +158,44 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
dgl_warning('Arguments edge_attrs has been deprecated. Please use'
' edata instead.')
edata = edge_attrs
if not (is_all(ndata) or isinstance(ndata, list)):
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata)))
if not (is_all(edata) or isinstance(edata, list)):
if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format(
type(edata)))
if any(g.is_block for g in graphs):
raise DGLError("Batching a block is not supported.")

utils.check_all_same_device(graphs, 'graphs')
utils.check_all_same_idtype(graphs, 'graphs')
relations = graphs[0].canonical_etypes
ntypes = graphs[0].ntypes
idtype = graphs[0].idtype
device = graphs[0].device

# Batch graph structure for each relation graph
edge_dict = defaultdict(list)
num_nodes_dict = defaultdict(int)
for g in graphs:
for rel in relations:
srctype, etype, dsttype = rel
u, v = g.edges(order='eid', etype=rel)
src = u + num_nodes_dict[srctype]
dst = v + num_nodes_dict[dsttype]
edge_dict[rel].append((src, dst))
for ntype in ntypes:
num_nodes_dict[ntype] += g.number_of_nodes(ntype)
for rel in relations:
src, dst = zip(*edge_dict[rel])
edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0))
retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device)
relations = list(sorted(graphs[0].canonical_etypes))
ntypes = list(sorted(graphs[0].ntypes))
etypes = [etype for _, etype, _ in relations]

gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
retg = DGLHeteroGraph(gidx, ntypes, etypes)

# Compute batch num nodes
bnn = {}
for ntype in graphs[0].ntypes:
for ntype in ntypes:
bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0)
retg.set_batch_num_nodes(bnn)

# Compute batch num edges
bne = {}
for etype in graphs[0].canonical_etypes:
for etype in relations:
bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0)
retg.set_batch_num_edges(bne)

# Batch node feature
if ndata is not None:
for ntype in graphs[0].ntypes:
for ntype in ntypes:
feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype))
retg.nodes[ntype].data.update(ret_feat)

# Batch edge feature
if edata is not None:
for etype in graphs[0].canonical_etypes:
for etype in relations:
feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype))
retg.edges[etype].data.update(ret_feat)
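
The rewritten batch() above no longer builds the batched structure in Python (looping over graphs and relations, offsetting edge IDs, and concatenating them with F.cat); it passes the component graph indices to one disjoint_union call on the C++ side and keeps only feature concatenation and the per-component node/edge counts in Python. Passing ndata=None or edata=None now skips feature batching entirely. As a pure-NumPy illustration (not DGL code) of what the disjoint union computes, each component's node IDs are shifted by the number of nodes that precede it:

import numpy as np

def disjoint_union_edges(edge_lists, num_nodes_list):
    # Toy model of the offset bookkeeping the C++ DisjointUnionCoo kernel performs.
    srcs, dsts, offset = [], [], 0
    for (src, dst), n in zip(edge_lists, num_nodes_list):
        srcs.append(np.asarray(src) + offset)
        dsts.append(np.asarray(dst) + offset)
        offset += n                        # shift the next component past this one
    return np.concatenate(srcs), np.concatenate(dsts), offset

src, dst, total_nodes = disjoint_union_edges(
    edge_lists=[([0], [1]), ([0, 1], [1, 2])],
    num_nodes_list=[2, 3])
print(src, dst, total_nodes)               # [0 2 3] [1 3 4] 5

For heterographs the source and destination offsets are tracked per node type, as the removed Python loop above did with num_nodes_dict.
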
1 change: 0 additions & 1 deletion src/array/array.cc
@@ -202,7 +202,6 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
}

NDArray Concat(const std::vector<IdArray>& arrays) {
CHECK(arrays.size() > 1) << "Number of arrays should larger than 1";
IdArray ret;

int64_t len = 0, offset = 0;
4 changes: 3 additions & 1 deletion src/array/cuda/utils.h
@@ -23,7 +23,9 @@ namespace cuda {
* and is also power of two.
*/
inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
CHECK_NE(dim, 0);
CHECK_GE(dim, 0);
if (dim == 0)
return 1;
int ret = max_nthrs;
while (ret > dim) {
ret = ret >> 1;
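
FindNumThreads picks the largest power-of-two CUDA thread count that does not exceed dim, capped at CUDA_MAX_NUM_THREADS; with this change a zero-sized dimension falls back to one thread instead of failing the old CHECK_NE, presumably because empty relations in a batched graph can now reach these kernels. A small Python model of the same logic (the 1024 cap is an assumption for illustration):

def find_num_threads(dim, max_nthrs=1024):
    # Mirror of src/array/cuda/utils.h: largest power of two <= dim, capped at max_nthrs.
    assert dim >= 0
    if dim == 0:
        return 1            # empty input: launch a single thread
    ret = max_nthrs
    while ret > dim:
        ret >>= 1
    return ret

assert find_num_threads(0) == 1
assert find_num_threads(3) == 2
assert find_num_threads(5000) == 1024
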
4 changes: 0 additions & 4 deletions src/array/union_partition.cc
@@ -10,8 +10,6 @@ namespace dgl {
namespace aten {
///////////////////////// COO Based Operations/////////////////////////
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
CHECK(coos.size() > 1) <<
"The length of input COOMatrix vector should be larger than 1";
uint64_t src_offset = 0, dst_offset = 0;
int64_t edge_data_offset = 0;
bool has_data = false;
@@ -114,8 +112,6 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(

///////////////////////// CSR Based Operations/////////////////////////
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
CHECK(csrs.size() > 1) <<
"The length of input CSRMatrix vector should be larger than 1";
uint64_t src_offset = 0, dst_offset = 0;
int64_t indices_offset = 0;
bool has_data = false;
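
Here, and in the Concat change above, the kernels previously insisted on at least two inputs; dropping those checks presumably lets the same C++ path handle degenerate batches uniformly, such as a batch with a single component graph (or re-batching an already batched graph, per the existing test_batching_batched test). A hedged usage sketch with an illustrative graph:

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
bg = dgl.batch([g])                  # a "batch" of one component is allowed
print(bg.batch_num_nodes())          # tensor([3])
print(bg.batch_num_edges())          # tensor([2])
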
4 changes: 4 additions & 0 deletions src/graph/heterograph.cc
@@ -124,6 +124,10 @@ void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPt
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
}
auto ctx = rel_graphs[0]->Context();
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context.";
}
}

std::vector<int64_t>
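
The new sanity check requires every relation graph to sit in the same context (device), which lines up with the check_all_same_device guard that the Python batch() gains in this commit. A hedged illustration of the user-level constraint (the CUDA branch only runs when a GPU is present):

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
if torch.cuda.is_available():
    g_gpu = g.to('cuda:0')
    # dgl.batch([g, g_gpu])          # rejected: components live on different devices
    bg = dgl.batch([g_gpu, g_gpu])   # fine: both components on cuda:0
else:
    bg = dgl.batch([g, g])
print(bg.device, bg.batch_size)
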
22 changes: 0 additions & 22 deletions src/graph/heterograph_capi.cc
@@ -559,28 +559,6 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v
*rv = ret_list;
});

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 0)
<< "Expect graph list has at least one graph";
std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits();
for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to batch have the same index dtype(int" << bits
<< "), but got int" << component->NumBits();
}
ATEN_ID_BITS_SWITCH(bits, IdType, {
auto hgptr =
DisjointUnionHeteroGraph<IdType>(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
});

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
77 changes: 12 additions & 65 deletions src/graph/transform/union_partition.cc
@@ -91,33 +91,34 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

// Loop over all ntypes
for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
uint64_t offset = 0;
for (const auto &cg : component_graphs)
offset += cg->NumVertices(vtype);
num_nodes_per_type[vtype] = offset;
}

// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
uint64_t src_offset = 0, dst_offset = 0;
HeteroGraphPtr rgptr = nullptr;

const dgl_format_code_t code =\
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// do some preprocess
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
for (const auto &cg : component_graphs) {
const dgl_format_code_t cur_code = cg->GetRelationGraph(etype)->GetAllowedFormats();
if (cur_code != code)
LOG(FATAL) << "All components should have the same formats";

// Update offsets
src_offset += cg->NumVertices(src_vtype);
dst_offset += cg->NumVertices(dst_vtype);
}

// prefer COO
if (FORMAT_HAS_COO(code)) {
std::vector<aten::COOMatrix> coos;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
for (const auto &cg : component_graphs) {
aten::COOMatrix coo = cg->GetCOOMatrix(etype);
coos.push_back(coo);
}
@@ -128,8 +129,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
(src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
for (const auto &cg : component_graphs) {
aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
csrs.push_back(csr);
}
@@ -141,8 +141,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
} else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
for (const auto &cg : component_graphs) {
aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
cscs.push_back(csc);
}
@@ -152,8 +151,6 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
(src_vtype == dst_vtype) ? 1 : 2, res, code);
}
rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
}

return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
@@ -272,56 +269,6 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
return rst;
}

template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
IdType src_offset = 0, dst_offset = 0;
std::vector<IdType> result_src, result_dst;

// Loop over all graphs
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
EdgeArray edges = cg->Edges(etype);
size_t num_edges = cg->NumEdges(etype);
const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);

// Loop over all edges
for (size_t j = 0; j < num_edges; ++j) {
// TODO(mufei): Should use array operations to implement this.
result_src.push_back(edges_src_data[j] + src_offset);
result_dst.push_back(edges_dst_data[j] + dst_offset);
}
// Update offsets
src_offset += cg->NumVertices(src_vtype);
dst_offset += cg->NumVertices(dst_vtype);
}
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype) ? 1 : 2, src_offset, dst_offset,
aten::VecToIdArray(result_src, sizeof(IdType) * 8),
aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
}
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}

template HeteroGraphPtr DisjointUnionHeteroGraph<int32_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);

template HeteroGraphPtr DisjointUnionHeteroGraph<int64_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);

template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
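
The refactored DisjointUnionHeteroGraph2 now fills num_nodes_per_type in its own loop over node types, checks that every component exposes the same allowed sparse formats for each relation, and unions each relation in whichever format is available, preferring COO, then CSR, then CSC (which shares CSRMatrix storage); the older per-edge DisjointUnionHeteroGraph template that copied edges one by one is removed. A runnable toy model of the format dispatch (the bitmask values are made up, not the real dgl_format_code_t encoding):

# Toy format codes; only their use as a bitmask mirrors the C++ code.
COO, CSR, CSC = 0b001, 0b010, 0b100

def pick_union_format(component_codes):
    code = component_codes[0]
    if any(c != code for c in component_codes):
        raise ValueError("All components should have the same formats")
    if code & COO:
        return "coo"     # preferred: concatenate rows/cols with per-type offsets
    if code & CSR:
        return "csr"
    if code & CSC:
        return "csc"     # stored as a CSRMatrix internally
    raise ValueError("no available sparse format")

assert pick_union_format([COO | CSR, COO | CSR]) == "coo"
assert pick_union_format([CSR, CSR]) == "csr"
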
7 changes: 4 additions & 3 deletions src/runtime/ndarray.cc
@@ -204,9 +204,10 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
if (size > 0)
ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
return ret;
}

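
NDArray::Empty now requests device memory only when the tensor has a non-zero byte size, so zero-element arrays (presumably such as the edge arrays of an empty relation in a batched graph) no longer trigger a zero-byte device allocation. A tiny illustrative-only Python analogue (device_alloc is a stand-in, not a DGL API):

def empty(shape, itemsize, device_alloc):
    size = itemsize
    for dim in shape:
        size *= dim
    # Skip the allocator entirely for zero-byte tensors.
    data = device_alloc(size) if size > 0 else None
    return {"shape": tuple(shape), "data": data}

buf = empty((0, 4), itemsize=8, device_alloc=bytearray)
assert buf["data"] is None
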
2 changes: 1 addition & 1 deletion tests/compute/test_batched_heterograph.py
@@ -325,6 +325,6 @@ def test_unbatch2(idtype):
#test_topology('int32')
#test_batching_batched('int32')
#test_batched_features('int32')
#test_empty_relation('int32')
# test_empty_relation('int64')
#test_to_device('int32')
pass
