Skip to content

Commit

Permalink
[BUG] Fix dmlc#1409 (dmlc#1411)
Browse files Browse the repository at this point in the history
* [BUG] Fix dmlc#1409

* fix test
  • Loading branch information
BarclayII authored Apr 3, 2020
1 parent d3560b7 commit 24dc71f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
44 changes: 40 additions & 4 deletions src/graph/unit_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -999,9 +999,27 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
// We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
HeteroSubgraph ret;
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), nullptr, subcsr, nullptr));

CSRPtr subcsr = nullptr;
CSRPtr subcsc = nullptr;
COOPtr subcoo = nullptr;
switch (fmt) {
case SparseFormat::kCSR:
subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCSC:
subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCOO:
subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
break;
default:
LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
return ret;
}

ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges);
return ret;
Expand All @@ -1011,9 +1029,27 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes) const {
SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
HeteroSubgraph ret;
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), nullptr, nullptr, subcoo));

CSRPtr subcsr = nullptr;
CSRPtr subcsc = nullptr;
COOPtr subcoo = nullptr;
switch (fmt) {
case SparseFormat::kCSR:
subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCSC:
subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCOO:
subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
break;
default:
LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
return ret;
}

ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges);
return ret;
Expand Down
12 changes: 12 additions & 0 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,18 @@ def _check_typed_subgraph2(g, sg):
sg5 = g.edge_type_subgraph(['follows', 'plays', 'wishes'])
_check_typed_subgraph1(g, sg5)

# Test for restricted format
for fmt in ['csr', 'csc', 'coo']:
g = dgl.graph([(0, 1), (1, 2)], restrict_format=fmt)
sg = g.subgraph({g.ntypes[0]: [1, 0]})
nids = F.asnumpy(sg.ndata[dgl.NID])
assert np.array_equal(nids, np.array([1, 0]))
src, dst = sg.all_edges(order='eid')
src = F.asnumpy(src)
dst = F.asnumpy(dst)
assert np.array_equal(src, np.array([1]))
assert np.array_equal(dst, np.array([0]))

def test_apply():
def node_udf(nodes):
return {'h': nodes.data['h'] * 2}
Expand Down

0 comments on commit 24dc71f

Please sign in to comment.