[NN] Support Unidirectional Bipartite Graphs in CFConv (dmlc#2674)
* Update

* update

* Update

Co-authored-by: Jinjing Zhou <[email protected]>
mufeili and VoVAllen authored Mar 3, 2021
1 parent 97bdae9 commit 7380d61
Showing 2 changed files with 23 additions and 9 deletions.
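
For readers unfamiliar with the term in the commit title, the sketch below (not part of the commit) shows what DGL calls a unidirectional bipartite graph: every edge goes from one node type to another, and the graph exposes separate source and destination node sets. The node and edge type names are made up for illustration.

import dgl
import torch

# Hypothetical graph: 4 'user' nodes, 3 'item' nodes, all edges from users to items.
g = dgl.heterograph({
    ('user', 'rates', 'item'): (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 0]))
})
print(g.number_of_src_nodes(), g.number_of_dst_nodes())  # 4 3
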
20 changes: 14 additions & 6 deletions python/dgl/nn/pytorch/conv/cfconv.py
@@ -123,15 +123,23 @@ def forward(self, g, node_feats, edge_feats):
----------
g : DGLGraph
The graph.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
node_feats : torch.Tensor or pair of torch.Tensor
The input node features. If a torch.Tensor is given, it represents the input
node features of shape :math:`(N, D_{in})`, where :math:`D_{in}` is the size of
the input features and :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, which is the case for a bipartite graph,
the pair must contain two tensors of shape :math:`(N_{src}, D_{in_{src}})` and
:math:`(N_{dst}, D_{in_{dst}})` for the source and destination nodes, respectively.
edge_feats : torch.Tensor
The input edge feature of shape :math:`(E, edge_in_feats)`
where :math:`E` is the number of edges.
Returns
-------
float32 tensor of shape (V, out_feats)
Updated node representations.
torch.Tensor
The output node feature of shape :math:`(N_{out}, out_feats)`
where :math:`N_{out}` is the number of destination nodes.
"""
with g.local_scope():
if isinstance(node_feats, tuple):
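To make the docstring above concrete, here is a minimal usage sketch (not taken from the repository). It assumes the constructor signature CFConv(node_in_feats, edge_in_feats, hidden_feats, out_feats) and reuses the toy bipartite graph from the earlier sketch; passing a (src_feats, dst_feats) pair selects the bipartite branch, and the output has one row per destination node.

import dgl
import torch
from dgl.nn.pytorch import CFConv

# Hypothetical bipartite graph: 4 source nodes, 3 destination nodes.
g = dgl.heterograph({
    ('user', 'rates', 'item'): (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 0]))
})
src_feats = torch.randn(g.number_of_src_nodes(), 2)   # node_in_feats = 2
dst_feats = torch.randn(g.number_of_dst_nodes(), 2)
edge_feats = torch.randn(g.number_of_edges(), 3)      # edge_in_feats = 3

conv = CFConv(node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=4)

# A single tensor of source features still works, as in the updated test below.
out = conv(g, src_feats, edge_feats)

# New in this commit: a (src, dst) pair for unidirectional bipartite graphs.
out = conv(g, (src_feats, dst_feats), edge_feats)
print(out.shape)  # expected torch.Size([3, 4]): one row per destination node
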
12 changes: 9 additions & 3 deletions tests/pytorch/test_nn.py
@@ -923,7 +923,7 @@ def test_atomic_conv(g, idtype):
assert h.shape[-1] == 4

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 3])
def test_cf_conv(g, idtype, out_dim):
g = g.astype(idtype).to(F.ctx())
@@ -936,9 +936,15 @@ def test_cf_conv(g, idtype, out_dim):
if F.gpu_ctx():
cfconv = cfconv.to(ctx)

node_feats = F.randn((g.number_of_nodes(), 2))
src_feats = F.randn((g.number_of_src_nodes(), 2))
edge_feats = F.randn((g.number_of_edges(), 3))
h = cfconv(g, node_feats, edge_feats)
h = cfconv(g, src_feats, edge_feats)
# currently we only do a shape check
assert h.shape[-1] == out_dim

# case for bipartite graphs
dst_feats = F.randn((g.number_of_dst_nodes(), 3))
h = cfconv(g, (src_feats, dst_feats), edge_feats)
# currently we only do a shape check
assert h.shape[-1] == out_dim

