Skip to content

Commit

Permalink
[Dataset] RedditDataset change data.train_mask to numpy array (dmlc#1961
Browse files Browse the repository at this point in the history
)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c.

* Update reddit.py

Co-authored-by: xiang song(charlie.song) <[email protected]>
  • Loading branch information
HuXiangkun and classicsong authored Aug 7, 2020
1 parent 35c9473 commit 5d5436b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/dgl/data/reddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ class RedditDataset(DGLBuiltinDataset):
Graph of the dataset
num_labels : int
Number of classes for each node
train_mask: Tensor
train_mask: numpy.ndarray
Mask of training nodes
val_mask: Tensor
val_mask: numpy.ndarray
Mask of validation nodes
test_mask: Tensor
test_mask: numpy.ndarray
Mask of test nodes
features : Tensor
Node features
Expand Down Expand Up @@ -202,17 +202,17 @@ def graph(self):
@property
def train_mask(self):
deprecate_property('dataset.train_mask', 'graph.ndata[\'train_mask\']')
return self._graph.ndata['train_mask']
return F.asnumpy(self._graph.ndata['train_mask'])

@property
def val_mask(self):
deprecate_property('dataset.val_mask', 'graph.ndata[\'val_mask\']')
return self._graph.ndata['val_mask']
return F.asnumpy(self._graph.ndata['val_mask'])

@property
def test_mask(self):
deprecate_property('dataset.test_mask', 'graph.ndata[\'test_mask\']')
return self._graph.ndata['test_mask']
return F.asnumpy(self._graph.ndata['test_mask'])

@property
def features(self):
Expand Down

0 comments on commit 5d5436b

Please sign in to comment.