Skip to content

Commit

Permalink
[Bugfix] Allows node types with no inbound/outbound edges (dmlc#1323)
Browse files Browse the repository at this point in the history
* add num nodes in ctors

* fix

* lint

* addresses comments

* replace with constexpr

* remove function with rvalue reference

* address comments
  • Loading branch information
BarclayII authored Mar 7, 2020
1 parent ac74233 commit ce6e19f
Show file tree
Hide file tree
Showing 21 changed files with 304 additions and 101 deletions.
18 changes: 16 additions & 2 deletions include/dgl/base_heterograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class BaseHeteroGraph : public runtime::Object {
/*! \return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;

/*! \return the number of vertices for each type in the graph as a vector */
inline virtual std::vector<int64_t> NumVerticesPerType() const {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
return {};
}

/*! \return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0;

Expand Down Expand Up @@ -543,9 +549,14 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);

// Declarations of functions and algorithms

/*! \brief Create a heterograph from meta graph and a list of bipartite graph */
/*!
* \brief Create a heterograph from meta graph and a list of bipartite graph,
* additionally specifying number of nodes per type.
*/
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr> &rel_graphs,
const std::vector<int64_t> &num_nodes_per_type = {});

/*!
* \brief Create a heterograph from COO input.
Expand Down Expand Up @@ -651,6 +662,9 @@ struct HeteroPickleStates : public runtime::Object {
/*! \brief Metagraph. */
GraphPtr metagraph;

/*! \brief Number of nodes per type */
std::vector<int64_t> num_nodes_per_type;

/*! \brief adjacency matrices of each relation graph */
std::vector<std::shared_ptr<SparseMatrix> > adjs;

Expand Down
45 changes: 44 additions & 1 deletion include/dgl/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,45 @@
#include "serializer.h"
#include "shared_mem.h"

/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}

/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}

namespace dgl {

/*!
* \brief Type traits that converts a C type to a DLDataType.
*
* Usage:
* DLDataTypeTraits<int>::dtype == dtype
*/
template<typename T>
struct DLDataTypeTraits {
static constexpr DLDataType dtype{0, 0, 0}; // dummy
};
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \
template<> \
struct DLDataTypeTraits<T> { \
static constexpr DLDataType dtype{code, bits, 1}; \
}
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR

namespace runtime {

/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
Expand Down Expand Up @@ -191,8 +228,14 @@ class NDArray {
DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});

/*!
* \brief Create a std::vector from a 1D NDArray.
* \tparam T Type of vector data.
* \note Type casting is NOT performed. The caller has to make sure that the vector
* type matches the dtype of NDArray.
*/
template<typename T>
static NDArray FromVector(const std::vector<T>& vec, DLDataType dtype, DLContext ctx);
std::vector<T> ToVector() const;

/*!
* \brief Function to copy data from one array to another.
Expand Down
47 changes: 29 additions & 18 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
else:
raise DGLError('Unsupported graph data type:', type(data))

def hetero_from_relations(rel_graphs):
def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
"""Create a heterograph from graphs representing connections of each relation.
The input is a list of heterographs where the ``i``th graph contains edges of type
Expand All @@ -294,6 +294,9 @@ def hetero_from_relations(rel_graphs):
----------
rel_graphs : list of DGLHeteroGraph
Each element corresponds to a heterograph for one (src, edge, dst) relation.
num_nodes_per_type : dict[str, Tensor], optional
Number of nodes per node type. If not given, DGL will infer the number of nodes
from the given relation graphs.
Returns
-------
Expand Down Expand Up @@ -349,32 +352,35 @@ def hetero_from_relations(rel_graphs):
# TODO(minjie): this API can be generalized as a union operation of the input graphs
# TODO(minjie): handle node/edge data
# infer meta graph
ntype_set = set()
meta_edges = []
meta_edges_src, meta_edges_dst = [], []
ntypes = []
etypes = []
# TODO(BarclayII): I'm keeping the node type names sorted because even if
# the metagraph is the same, the same node type name in different graphs may
# map to different node type IDs.
# In the future, we need to lower the type names into C++.
for rgrh in rel_graphs:
assert len(rgrh.etypes) == 1
stype, etype, dtype = rgrh.canonical_etypes[0]
ntype_set.add(stype)
ntype_set.add(dtype)
ntypes = list(sorted(ntype_set))
if num_nodes_per_type is None:
ntype_set = set()
for rgrh in rel_graphs:
assert len(rgrh.etypes) == 1
stype, etype, dtype = rgrh.canonical_etypes[0]
ntype_set.add(stype)
ntype_set.add(dtype)
ntypes = list(sorted(ntype_set))
else:
ntypes = list(sorted(num_nodes_per_type.keys()))
num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes])
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
for rgrh in rel_graphs:
stype, etype, dtype = rgrh.canonical_etypes[0]
stid = ntype_dict[stype]
dtid = ntype_dict[dtype]
meta_edges.append((stid, dtid))
meta_edges_src.append(ntype_dict[stype])
meta_edges_dst.append(ntype_dict[dtype])
etypes.append(etype)
metagraph = graph_index.from_edge_list(meta_edges, True, True)
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True, True)

# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs])
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
retg = DGLHeteroGraph(hgidx, ntypes, etypes)
for i, rgrh in enumerate(rel_graphs):
for ntype in rgrh.ntypes:
Expand Down Expand Up @@ -441,8 +447,12 @@ def heterograph(data_dict, num_nodes_dict=None):
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc
elif isinstance(data, DGLHeteroGraph):
# Do nothing; handled in the next loop
continue
# original node type and edge type of ``data`` is ignored.
assert len(data.canonical_etypes) == 1, \
"Relational graphs must have only one edge type."
srctype, _, dsttype = data.canonical_etypes[0]
nsrc = data.number_of_nodes(srctype)
ndst = data.number_of_nodes(dsttype)
else:
raise DGLError('Unsupported graph data type %s for %s' % (
type(data), (srctype, etype, dsttype)))
Expand All @@ -464,7 +474,7 @@ def heterograph(data_dict, num_nodes_dict=None):
data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))

return hetero_from_relations(rel_graphs)
return hetero_from_relations(rel_graphs, num_nodes_dict)

def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
"""Convert the given homogeneous graph to a heterogeneous graph.
Expand Down Expand Up @@ -622,7 +632,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
card=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph)

hg = hetero_from_relations(rel_graphs)
hg = hetero_from_relations(
rel_graphs, {ntype: count for ntype, count in zip(ntypes, ntype_count)})

ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
for ntid, ntype in enumerate(hg.ntypes):
Expand Down
8 changes: 6 additions & 2 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,7 @@ def node_type_subgraph(self, ntypes):
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = []

num_nodes_per_type = [self.number_of_nodes(ntype) for ntype in ntypes]
ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
for i in range(len(self._etypes)):
Expand All @@ -1697,7 +1698,8 @@ def node_type_subgraph(self, ntypes):
edge_frames.append(self._edge_frames[i])

metagraph = graph_index.from_edge_list(meta_edges, True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
return hg

Expand Down Expand Up @@ -1767,9 +1769,11 @@ def edge_type_subgraph(self, etypes):
edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]

metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
hgidx = heterograph_index.create_heterograph_from_relations(metagraph, rel_graphs)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg

Expand Down
28 changes: 23 additions & 5 deletions python/dgl/heterograph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edg
indptr.todgltensor(), indices.todgltensor(), edge_ids.todgltensor(),
restrict_format)

def create_heterograph_from_relations(metagraph, rel_graphs):
def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type):
"""Create a heterograph from metagraph and graphs of every relation.
Parameters
Expand All @@ -1017,12 +1017,18 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
Meta-graph.
rel_graphs : list of HeteroGraphIndex
Bipartite graph of each relation.
num_nodes_per_type : utils.Index, optional
Number of nodes per node type
Returns
-------
HeteroGraphIndex
"""
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
if num_nodes_per_type is None:
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
else:
return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
metagraph, rel_graphs, num_nodes_per_type.todgltensor())

def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs.
Expand Down Expand Up @@ -1085,6 +1091,17 @@ def metagraph(self):
"""
return _CAPI_DGLHeteroPickleStatesGetMetagraph(self)

@property
def num_nodes_per_type(self):
"""Number of nodes per edge type
Returns
-------
Tensor
Array of number of nodes for each type
"""
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLHeteroPickleStatesGetNumVertices(self))

@property
def adjs(self):
"""Adjacency matrices of all the relation graphs
Expand All @@ -1097,11 +1114,12 @@ def adjs(self):
return list(_CAPI_DGLHeteroPickleStatesGetAdjs(self))

def __getstate__(self):
return self.metagraph, self.adjs
return self.metagraph, self.num_nodes_per_type, self.adjs

def __setstate__(self, state):
metagraph, adjs = state
metagraph, num_nodes_per_type, adjs = state
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStates, metagraph, adjs)
_CAPI_DGLCreateHeteroPickleStates, metagraph, num_nodes_per_type, adjs)

_init_api("dgl.heterograph_index")
4 changes: 4 additions & 0 deletions src/array/cpu/array_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class IdHashMap {
return values;
}

inline size_t Size() const {
return oldv2newv_.size();
}

private:
static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1;
Expand Down
14 changes: 7 additions & 7 deletions src/array/cpu/spmat_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
}
}
}
return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->ctx);
}

template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
Expand Down Expand Up @@ -228,7 +228,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
}
}

return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
return NDArray::FromVector(ret_vec, csr.data->ctx);
}

template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
Expand Down Expand Up @@ -306,9 +306,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
}
}

return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)};
return {NDArray::FromVector(ret_rows, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->ctx)};
}

template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
Expand Down Expand Up @@ -536,8 +536,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols,
NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx),
NDArray::FromVector(sub_indptr, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->ctx),
sub_data_arr};
}

Expand Down
10 changes: 0 additions & 10 deletions src/c_api_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,6 @@

using dgl::runtime::operator<<;

/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}

/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}

/*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
return os << ctx.device_type << ":" << ctx.device_id;
Expand Down
6 changes: 4 additions & 2 deletions src/graph/creators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ namespace dgl {

// creator implementation
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
}

HeteroGraphPtr CreateFromCOO(
Expand Down
Loading

0 comments on commit ce6e19f

Please sign in to comment.