Skip to content

Commit

Permalink
Fix GINDT and dmlc#2087 (dmlc#2103)
Browse files Browse the repository at this point in the history
* fix gindt

* ff

* fix

* minor fix

* fix
  • Loading branch information
VoVAllen authored Aug 26, 2020
1 parent 628d9fc commit 51ba662
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
51 changes: 31 additions & 20 deletions python/dgl/data/gindt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..utils import retry_method_with_fix
from ..convert import graph as dgl_graph


class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN)
Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
Expand Down Expand Up @@ -232,32 +233,36 @@ def process(self):
if self.degree_as_nlabel:
if self.verbose:
print('generate node features by node degree...')
nlabel_set = set([])
for g in self.graphs:
# Ideally the existing label should not be overwritten,
# in case users want to keep it,
# but graphs without features usually have no labels either, so this is fine.
g.ndata['label'] = g.in_degrees()
# extracting unique node labels
nlabel_set = nlabel_set.union(set([F.as_scalar(nl) for nl in g.ndata['label']]))

nlabel_set = list(nlabel_set)
# in case the labels/degrees are not continuous number
self.ndegree_dict = {
# in case the labels/degrees are not contiguous numbers
nlabel_set = set([])
for g in self.graphs:
nlabel_set = nlabel_set.union(
set([F.as_scalar(nl) for nl in g.ndata['label']]))
nlabel_set = list(nlabel_set)
if len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
# Note: this differs from the original author's implementation. In weihua916's implementation,
# the labels are always relabeled. Here we do not relabel them when they are already contiguous
# (0..N-1), so that they stay consistent with the original dataset.
label2idx = self.nlabel_dict
else:
label2idx = {
nlabel_set[i]: i
for i in range(len(nlabel_set))
}
label2idx = self.ndegree_dict
# generate node attr by node label
else:
if self.verbose:
print('generate node features by node label...')
label2idx = self.nlabel_dict

for g in self.graphs:
g.ndata['attr'] = F.tensor(np.zeros((
g.number_of_nodes(), len(label2idx))))
g.ndata['attr'][range(g.number_of_nodes()), [label2idx[F.as_scalar(F.reshape(nl, (1,)))] for nl in g.ndata['label']]] = 1
attr = np.zeros((
g.number_of_nodes(), len(label2idx)))
attr[range(g.number_of_nodes()), [label2idx[nl]
for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1
g.ndata['attr'] = F.tensor(attr)

# after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict)
Expand Down Expand Up @@ -288,8 +293,10 @@ def process(self):
self.nlabel_dict, self.ndegree_dict))

def save(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
label_dict = {'labels': self.labels}
info_dict = {'N': self.N,
'n': self.n,
Expand All @@ -308,8 +315,10 @@ def save(self):
save_info(str(info_path), info_dict)

def load(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))

Expand All @@ -331,8 +340,10 @@ def load(self):
self.degree_as_nlabel = info_dict['degree_as_nlabel']

def has_cache(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
6 changes: 3 additions & 3 deletions python/dgl/data/graph_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def save_graphs(filename, g_list, labels=None):
load_graphs
"""
# if it is local file, do some sanity check
if filename.startswith('s3://') is False:
if not filename.startswith('s3://'):
if os.path.isdir(filename):
raise DGLError("Filename {} is an existing directory.".format(filename))
f_path, _ = os.path.split(filename)
if not os.path.exists(f_path):
f_path = os.path.dirname(filename)
if f_path and not os.path.exists(f_path):
os.makedirs(f_path)

g_sample = g_list[0] if isinstance(g_list, list) else g_list
Expand Down

0 comments on commit 51ba662

Please sign in to comment.