Skip to content

Commit

Permalink
[Bug] Fix dmlc#1088 (dmlc#1089)
Browse files Browse the repository at this point in the history
* [Bug] Fix dmlc#1088

* fix

* add comment
  • Loading branch information
BarclayII authored Dec 11, 2019
1 parent 48c7ec4 commit 6ae93e5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/graph/heterograph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
if (hg->NumEdgeTypes() == 1) {
CHECK_EQ(etype, 0);
*rv = hg;
} else {
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
// Test if the heterograph is a unit graph. If so, return itself.
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg.sptr());
if (bg != nullptr)
*rv = bg;
else
*rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
});

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
Expand Down
9 changes: 9 additions & 0 deletions tests/compute/test_hetero_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,14 @@ def foo(g):
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)

def test_issue_1088():
# This test ensures that message passing on a heterograph with one edge type
# would not crash (GitHub issue #1088).
import dgl.function as fn
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])})
g.nodes['U'].data['x'] = F.randn((3, 3))
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))

if __name__ == '__main__':
test_nx_conversion()
test_batch_setter_getter()
Expand All @@ -781,3 +789,4 @@ def foo(g):
test_group_apply_edges()
test_local_var()
test_local_scope()
test_issue_1088()

0 comments on commit 6ae93e5

Please sign in to comment.