Skip to content

Commit

Permalink
Fix citation data loading (dmlc#2421)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
classicsong and Ubuntu authored Dec 14, 2020
1 parent 0ceb675 commit 9d7bf4e
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 9 additions & 7 deletions python/dgl/data/citation_graph.py
Original file line number Diff line number Diff line change
@@ -171,13 +171,15 @@ def load(self):
graphs, _ = load_graphs(str(graph_path))

info = load_info(str(info_path))
self._g = graphs[0]
graph = graphs[0]
self._g = graph
# for compatability
graph = graph.clone()
graph.pop('train_mask')
graph.pop('val_mask')
graph.pop('test_mask')
graph.pop('feat')
graph.pop('label')
graph.ndata.pop('train_mask')
graph.ndata.pop('val_mask')
graph.ndata.pop('test_mask')
graph.ndata.pop('feat')
graph.ndata.pop('label')
graph = to_networkx(graph)
self._graph = nx.DiGraph(graph)

@@ -328,7 +330,7 @@ class CoraGraphDataset(CitationGraphDataset):
- Number of Classes: 7
- Label split:
- Train: 140
- Train: 140
- Valid: 500
- Test: 1000
2 changes: 2 additions & 0 deletions python/dgl/data/dgl_dataset.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from __future__ import absolute_import

import os, sys, hashlib
import traceback
import abc
from .utils import download, extract_archive, get_download_dir, makedirs
from ..utils import retry_method_with_fix
@@ -170,6 +171,7 @@ def _load(self):
except:
load_flag = False
if self.verbose:
print(traceback.format_exc())
print('Loading from cache failed, re-processing.')

if not load_flag:

0 comments on commit 9d7bf4e

Please sign in to comment.