From 366cc7ebf3eb3d4312c60fdbb80a83a88fe318df Mon Sep 17 00:00:00 2001
From: Jinjing Zhou
Date: Thu, 18 Mar 2021 17:56:28 +0800
Subject: [PATCH] [Pickle] Fix HeteroGraphConv pickle problem (#2761)

* fix pickle problem
* lint
* add pickle tests
* fix
* fix
* fix
* fix
* fix for windows
---
 python/dgl/nn/pytorch/hetero.py | 46 +++++++++++++++++---------
 tests/pytorch/test_nn.py        | 57 +++++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 15 deletions(-)

diff --git a/python/dgl/nn/pytorch/hetero.py b/python/dgl/nn/pytorch/hetero.py
index b77b854186fa..f89ec3bc26da 100644
--- a/python/dgl/nn/pytorch/hetero.py
+++ b/python/dgl/nn/pytorch/hetero.py
@@ -1,6 +1,8 @@
 """Heterograph NN modules"""
+from functools import partial
 import torch as th
 import torch.nn as nn
+from ...base import DGLError

 __all__ = ['HeteroGraphConv']

@@ -196,6 +198,29 @@ def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
             rsts[nty] = self.agg_fn(alist, nty)
         return rsts

+def _max_reduce_func(inputs, dim):
+    return th.max(inputs, dim=dim)[0]
+
+def _min_reduce_func(inputs, dim):
+    return th.min(inputs, dim=dim)[0]
+
+def _sum_reduce_func(inputs, dim):
+    return th.sum(inputs, dim=dim)
+
+def _mean_reduce_func(inputs, dim):
+    return th.mean(inputs, dim=dim)
+
+def _stack_agg_func(inputs, dsttype):  # pylint: disable=unused-argument
+    if len(inputs) == 0:
+        return None
+    return th.stack(inputs, dim=1)
+
+def _agg_func(inputs, dsttype, fn):  # pylint: disable=unused-argument
+    if len(inputs) == 0:
+        return None
+    stacked = th.stack(inputs, dim=0)
+    return fn(stacked, dim=0)
+
 def get_aggregate_fn(agg):
     """Internal function to get the aggregation function for node data
     generated from different relations.
@@ -213,28 +238,19 @@ def get_aggregate_fn(agg):
         and returns one aggregated tensor.
     """
     if agg == 'sum':
-        fn = th.sum
+        fn = _sum_reduce_func
     elif agg == 'max':
-        fn = lambda inputs, dim: th.max(inputs, dim=dim)[0]
+        fn = _max_reduce_func
     elif agg == 'min':
-        fn = lambda inputs, dim: th.min(inputs, dim=dim)[0]
+        fn = _min_reduce_func
     elif agg == 'mean':
-        fn = th.mean
+        fn = _mean_reduce_func
     elif agg == 'stack':
         fn = None  # will not be called
     else:
         raise DGLError('Invalid cross type aggregator. Must be one of '
                        '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
     if agg == 'stack':
-        def stack_agg(inputs, dsttype):  # pylint: disable=unused-argument
-            if len(inputs) == 0:
-                return None
-            return th.stack(inputs, dim=1)
-        return stack_agg
+        return _stack_agg_func
     else:
-        def aggfn(inputs, dsttype):  # pylint: disable=unused-argument
-            if len(inputs) == 0:
-                return None
-            stacked = th.stack(inputs, dim=0)
-            return fn(stacked, dim=0)
-        return aggfn
+        return partial(_agg_func, fn=fn)
diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py
index a747beb681dc..284af91af155 100644
--- a/tests/pytorch/test_nn.py
+++ b/tests/pytorch/test_nn.py
@@ -1,3 +1,4 @@
+import io
 import torch as th
 import networkx as nx
 import dgl
@@ -8,9 +9,12 @@
 from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
 from test_utils import parametrize_dtype
 from copy import deepcopy
+import pickle

 import scipy as sp

+tmp_buffer = io.BytesIO()
+
 def _AXWb(A, X, W, b):
     X = th.matmul(X, W)
     Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
@@ -25,6 +29,11 @@ def test_graph_conv0(out_dim):
     conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
     conv = conv.to(ctx)
     print(conv)
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
+
     # test#1: basic
     h0 = F.ones((3, 5))
     h1 = conv(g, h0)
@@ -119,6 +128,10 @@ def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
 def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
     ext_w = F.randn((5, out_dim)).to(F.ctx())
     nsrc = g.number_of_src_nodes()
     ndst = g.number_of_dst_nodes()
@@ -141,6 +154,10 @@ def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
     # Test a pair of tensor inputs
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
     ext_w = F.randn((5, out_dim)).to(F.ctx())
     nsrc = g.number_of_src_nodes()
     ndst = g.number_of_dst_nodes()
@@ -175,6 +192,9 @@ def test_tagconv(out_dim):
     conv = nn.TAGConv(5, out_dim, bias=True)
     conv = conv.to(ctx)
     print(conv)
+
+    # test pickle
+    th.save(conv, tmp_buffer)

     # test#1: basic
     h0 = F.ones((3, 5))
@@ -231,6 +251,9 @@ def test_glob_att_pool():
     gap = gap.to(ctx)
     print(gap)

+    # test pickle
+    th.save(gap, tmp_buffer)
+
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)
@@ -347,6 +370,10 @@ def test_rgcn(O):
     I = 10

     rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
+
+    # test pickle
+    th.save(rgc_basis, tmp_buffer)
+
     rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
     rgc_basis_low.weight = rgc_basis.weight
     rgc_basis_low.w_comp = rgc_basis.w_comp
@@ -509,6 +536,10 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     feat = F.randn((g.number_of_nodes(), 5))
     gat = gat.to(ctx)
     h = gat(g, feat)
+
+    # test pickle
+    th.save(gat, tmp_buffer)
+
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)
@@ -536,6 +567,8 @@ def test_sage_conv(idtype, g, aggre_type):
     sage = nn.SAGEConv(5, 10, aggre_type)
     feat = F.randn((g.number_of_nodes(), 5))
     sage = sage.to(F.ctx())
+    # test pickle
+    th.save(sage, tmp_buffer)
     h = sage(g, feat)
     assert h.shape[-1] == 10
@@ -583,6 +616,10 @@ def test_sgc_conv(g, idtype, out_dim):
     g = g.astype(idtype).to(ctx)
     # not cached
     sgc = nn.SGConv(5, out_dim, 3)
+
+    # test pickle
+    th.save(sgc, tmp_buffer)
+
     feat = F.randn((g.number_of_nodes(), 5))
     sgc = sgc.to(ctx)

@@ -605,6 +642,9 @@ def test_appnp_conv(g, idtype):
     appnp = nn.APPNPConv(10, 0.1)
     feat = F.randn((g.number_of_nodes(), 5))
     appnp = appnp.to(ctx)
+
+    # test pickle
+    th.save(appnp, tmp_buffer)
     h = appnp(g, feat)
     assert h.shape[-1] == 5

@@ -622,6 +662,10 @@ def test_gin_conv(g, idtype, aggregator_type):
     feat = F.randn((g.number_of_nodes(), 5))
     gin = gin.to(ctx)
     h = gin(g, feat)
+
+    # test pickle
+    th.save(gin, tmp_buffer)
+
     assert h.shape == (g.number_of_nodes(), 12)

 @parametrize_dtype
@@ -784,6 +828,10 @@ def test_edge_conv(g, idtype, out_dim):
     ctx = F.ctx()
     edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
     print(edge_conv)
+
+    # test pickle
+    th.save(edge_conv, tmp_buffer)
+
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = edge_conv(g, h0)
     assert h1.shape == (g.number_of_nodes(), out_dim)
@@ -811,6 +859,10 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
     dotgat = nn.DotGatConv(5, out_dim, num_heads)
     feat = F.randn((g.number_of_nodes(), 5))
     dotgat = dotgat.to(ctx)
+
+    # test pickle
+    th.save(dotgat, tmp_buffer)
+
     h = dotgat(g, feat)
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = dotgat(g, feat, get_attention=True)
@@ -919,6 +971,7 @@ def test_atomic_conv(g, idtype):
     dist = F.randn((g.number_of_edges(), 1))

     h = aconv(g, feat, dist)
+    # currently we only check the output shape
     assert h.shape[-1] == 4

@@ -968,6 +1021,10 @@ def test_hetero_conv(agg, idtype):
         'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
         agg)
     conv = conv.to(F.ctx())
+
+    # test pickle
+    th.save(conv, tmp_buffer)
+
     uf = F.randn((4, 2))
     gf = F.randn((4, 4))
     sf = F.randn((2, 3))
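
A note on why the fix works (reviewer sketch, not part of the patch itself):
torch.save serializes a module with Python's pickle, and pickle stores a
function by its importable, module-level name. The lambdas and the nested
stack_agg/aggfn closures previously created inside get_aggregate_fn have no
such name, so a HeteroGraphConv holding one of them as agg_fn could not be
pickled. Hoisting the reducers to module scope and binding the extra argument
with functools.partial preserves the behavior while making every agg_fn
picklable. Below is a minimal, standalone Python sketch of the failure mode;
the reducer mirrors the patch's _sum_reduce_func, and the variable names
(agg, bad) are illustrative only.

    import pickle
    from functools import partial

    import torch as th

    def _sum_reduce_func(inputs, dim):
        # Module-level function: pickle can reference it by qualified name.
        return th.sum(inputs, dim=dim)

    agg = partial(_sum_reduce_func, dim=0)
    pickle.dumps(agg)  # succeeds: partial of a named, module-level function

    bad = lambda inputs, dim: th.max(inputs, dim=dim)[0]
    try:
        pickle.dumps(bad)  # fails: a lambda has no importable name
    except (pickle.PicklingError, AttributeError) as err:
        print('cannot pickle lambda:', err)

This is also what the new tests exercise: each th.save(module, tmp_buffer)
call round-trips the module through pickle into an in-memory buffer, failing
loudly if any attribute of the module is unpicklable.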