[Pickle] Fix HeteroGraphConv pickle problem (dmlc#2761)
* fix pickle problem

* lint

* add pickle tests

* fix

* fix

* fix

* fix

* fix for windows
VoVAllen authored Mar 18, 2021
1 parent 337b155 commit 366cc7e
Showing 2 changed files with 88 additions and 15 deletions.
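For context on the fix below: Python's pickle serializes a function by its qualified import path, so lambdas and functions defined inside another function (exactly what get_aggregate_fn used to return) cannot be pickled. A minimal sketch of the failure mode, with illustrative names:

import pickle

def make_agg():
    # nested functions are "local objects" with no importable path
    def agg(inputs):
        return sum(inputs)
    return agg

pickle.dumps(sum)  # fine: a module-level name pickle can look up

try:
    pickle.dumps(lambda x: x)
except pickle.PicklingError as err:
    print(err)  # Can't pickle <function <lambda> ...>

try:
    pickle.dumps(make_agg())
except AttributeError as err:
    print(err)  # Can't pickle local object 'make_agg.<locals>.agg'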
46 changes: 31 additions & 15 deletions python/dgl/nn/pytorch/hetero.py
@@ -1,6 +1,8 @@
 """Heterograph NN modules"""
+from functools import partial
+
 import torch as th
 import torch.nn as nn
 from ...base import DGLError

 __all__ = ['HeteroGraphConv']

@@ -196,6 +198,29 @@ def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
                 rsts[nty] = self.agg_fn(alist, nty)
         return rsts

+def _max_reduce_func(inputs, dim):
+    return th.max(inputs, dim=dim)[0]
+
+def _min_reduce_func(inputs, dim):
+    return th.min(inputs, dim=dim)[0]
+
+def _sum_reduce_func(inputs, dim):
+    return th.sum(inputs, dim=dim)
+
+def _mean_reduce_func(inputs, dim):
+    return th.mean(inputs, dim=dim)
+
+def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
+    if len(inputs) == 0:
+        return None
+    return th.stack(inputs, dim=1)
+
+def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
+    if len(inputs) == 0:
+        return None
+    stacked = th.stack(inputs, dim=0)
+    return fn(stacked, dim=0)
+
 def get_aggregate_fn(agg):
     """Internal function to get the aggregation function for node data
     generated from different relations.
@@ -213,28 +238,19 @@ def get_aggregate_fn(agg):
     and returns one aggregated tensor.
     """
     if agg == 'sum':
-        fn = th.sum
+        fn = _sum_reduce_func
     elif agg == 'max':
-        fn = lambda inputs, dim: th.max(inputs, dim=dim)[0]
+        fn = _max_reduce_func
     elif agg == 'min':
-        fn = lambda inputs, dim: th.min(inputs, dim=dim)[0]
+        fn = _min_reduce_func
     elif agg == 'mean':
-        fn = th.mean
+        fn = _mean_reduce_func
     elif agg == 'stack':
         fn = None # will not be called
     else:
         raise DGLError('Invalid cross type aggregator. Must be one of '
                        '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
     if agg == 'stack':
-        def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
-            if len(inputs) == 0:
-                return None
-            return th.stack(inputs, dim=1)
-        return stack_agg
+        return _stack_agg_func
     else:
-        def aggfn(inputs, dsttype): # pylint: disable=unused-argument
-            if len(inputs) == 0:
-                return None
-            stacked = th.stack(inputs, dim=0)
-            return fn(stacked, dim=0)
-        return aggfn
+        return partial(_agg_func, fn=fn)
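The hoisted helpers plus functools.partial restore picklability: a partial object pickles cleanly as long as the wrapped function lives at module level and the bound arguments are themselves picklable. A minimal sketch under those assumptions (this _agg helper is illustrative, not the commit's exact code):

import pickle
from functools import partial

def _agg(inputs, dsttype, fn):  # module level, so pickle finds it by name
    return fn(inputs)

agg_fn = partial(_agg, fn=sum)  # the bound argument must also be picklable
restored = pickle.loads(pickle.dumps(agg_fn))
print(restored([1, 2, 3], dsttype=None))  # prints 6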
57 changes: 57 additions & 0 deletions tests/pytorch/test_nn.py
@@ -1,3 +1,4 @@
+import io
 import torch as th
 import networkx as nx
 import dgl
@@ -8,9 +9,12 @@
 from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
 from test_utils import parametrize_dtype
 from copy import deepcopy
+import pickle

 import scipy as sp

+tmp_buffer = io.BytesIO()
+
 def _AXWb(A, X, W, b):
     X = th.matmul(X, W)
     Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
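The tests share one in-memory tmp_buffer because th.save runs the whole module through pickle, making a BytesIO write a disk-free pickle check. A minimal sketch of the pattern, using an nn.Linear stand-in rather than a DGL module:

import io
import torch as th

buf = io.BytesIO()
model = th.nn.Linear(5, 3)

# torch.save pickles the full module; an unpicklable attribute
# (e.g. a lambda) would raise right here.
th.save(model, buf)

buf.seek(0)
restored = th.load(buf)  # on PyTorch >= 2.6, pass weights_only=False
print(restored)  # Linear(in_features=5, out_features=3, bias=True)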
Expand All @@ -25,6 +29,11 @@ def test_graph_conv0(out_dim):
conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
conv = conv.to(ctx)
print(conv)

# test pickle
th.save(conv, tmp_buffer)


# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(g, h0)
@@ -119,6 +128,10 @@ def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
 def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
     ext_w = F.randn((5, out_dim)).to(F.ctx())
     nsrc = g.number_of_src_nodes()
     ndst = g.number_of_dst_nodes()
Expand All @@ -141,6 +154,10 @@ def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
# Test a pair of tensor inputs
g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

# test pickle
th.save(conv, tmp_buffer)

ext_w = F.randn((5, out_dim)).to(F.ctx())
nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes()
@@ -175,6 +192,9 @@ def test_tagconv(out_dim):
     conv = nn.TAGConv(5, out_dim, bias=True)
     conv = conv.to(ctx)
     print(conv)
+
+    # test pickle
+    th.save(conv, tmp_buffer)

     # test#1: basic
     h0 = F.ones((3, 5))
@@ -231,6 +251,9 @@ def test_glob_att_pool():
     gap = gap.to(ctx)
     print(gap)
+
+    # test pickle
+    th.save(gap, tmp_buffer)

     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)
@@ -347,6 +370,10 @@ def test_rgcn(O):
     I = 10

     rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
+
+    # test pickle
+    th.save(rgc_basis, tmp_buffer)
+
     rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
     rgc_basis_low.weight = rgc_basis.weight
     rgc_basis_low.w_comp = rgc_basis.w_comp
@@ -509,6 +536,10 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     feat = F.randn((g.number_of_nodes(), 5))
     gat = gat.to(ctx)
     h = gat(g, feat)
+
+    # test pickle
+    th.save(gat, tmp_buffer)
+
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)
@@ -536,6 +567,8 @@ def test_sage_conv(idtype, g, aggre_type):
     sage = nn.SAGEConv(5, 10, aggre_type)
     feat = F.randn((g.number_of_nodes(), 5))
     sage = sage.to(F.ctx())
+    # test pickle
+    th.save(sage, tmp_buffer)
     h = sage(g, feat)
     assert h.shape[-1] == 10

@@ -583,6 +616,10 @@ def test_sgc_conv(g, idtype, out_dim):
     g = g.astype(idtype).to(ctx)
     # not cached
     sgc = nn.SGConv(5, out_dim, 3)
+
+    # test pickle
+    th.save(sgc, tmp_buffer)
+
     feat = F.randn((g.number_of_nodes(), 5))
     sgc = sgc.to(ctx)

@@ -605,6 +642,9 @@ def test_appnp_conv(g, idtype):
     appnp = nn.APPNPConv(10, 0.1)
     feat = F.randn((g.number_of_nodes(), 5))
     appnp = appnp.to(ctx)
+
+    # test pickle
+    th.save(appnp, tmp_buffer)

     h = appnp(g, feat)
     assert h.shape[-1] == 5
@@ -622,6 +662,10 @@ def test_gin_conv(g, idtype, aggregator_type):
     feat = F.randn((g.number_of_nodes(), 5))
     gin = gin.to(ctx)
     h = gin(g, feat)
+
+    # test pickle
+    th.save(h, tmp_buffer)
+
     assert h.shape == (g.number_of_nodes(), 12)

 @parametrize_dtype
@@ -784,6 +828,10 @@ def test_edge_conv(g, idtype, out_dim):
     ctx = F.ctx()
     edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
     print(edge_conv)
+
+    # test pickle
+    th.save(edge_conv, tmp_buffer)
+
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = edge_conv(g, h0)
     assert h1.shape == (g.number_of_nodes(), out_dim)
@@ -811,6 +859,10 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
     dotgat = nn.DotGatConv(5, out_dim, num_heads)
     feat = F.randn((g.number_of_nodes(), 5))
     dotgat = dotgat.to(ctx)
+
+    # test pickle
+    th.save(dotgat, tmp_buffer)
+
     h = dotgat(g, feat)
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = dotgat(g, feat, get_attention=True)
@@ -919,6 +971,7 @@ def test_atomic_conv(g, idtype):
     dist = F.randn((g.number_of_edges(), 1))

     h = aconv(g, feat, dist)
+
     # current we only do shape check
     assert h.shape[-1] == 4

@@ -968,6 +1021,10 @@ def test_hetero_conv(agg, idtype):
         'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
         agg)
     conv = conv.to(F.ctx())
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
     uf = F.randn((4, 2))
     gf = F.randn((4, 4))
     sf = F.randn((2, 3))
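End to end, the change means a HeteroGraphConv now survives a save/load round trip for any aggregate mode, since every mode previously handed back a function defined inside get_aggregate_fn. A hedged sketch (feature sizes and relation names are made up for illustration):

import io
import torch as th
import dgl.nn.pytorch as dglnn

conv = dglnn.HeteroGraphConv({
    'follows': dglnn.GraphConv(2, 3),
    'plays': dglnn.GraphConv(2, 3)},
    aggregate='mean')

buf = io.BytesIO()
th.save(conv, buf)  # failed with a pickling error before this commit
buf.seek(0)
restored = th.load(buf)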
