Skip to content

Commit c2c2e42

Browse files
committedOct 15, 2020
Added Reddit file + download option after Reddit broken link in PyG
1 parent c5d2d10 commit c2c2e42

File tree

5 files changed

+139
-56
lines changed

5 files changed

+139
-56
lines changed
 

‎gcn_distr.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import torch.distributed as dist
99

1010
from torch_geometric.data import Data, Dataset
11-
from torch_geometric.datasets import Planetoid, PPI, Reddit
11+
from torch_geometric.datasets import Planetoid, PPI
12+
from reddit import Reddit
1213
from torch_geometric.nn import GCNConv, ChebConv # noqa
1314
from torch_geometric.utils import add_remaining_self_loops, to_dense_adj, dense_to_sparse, to_scipy_sparse_matrix
1415
import torch_geometric.transforms as T
@@ -60,6 +61,7 @@
6061
acc_per_rank = 0
6162
run_count = 0
6263
run = 0
64+
download = False
6365

6466
def start_time(group, rank, subset=False, src=None):
6567
global barrier_time
@@ -675,21 +677,22 @@ def main():
675677
print(socket.gethostname())
676678
seed = 0
677679

678-
mp.set_start_method('spawn', force=True)
679-
outputs = None
680-
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
681-
dist.init_process_group(backend='nccl')
682-
rank = dist.get_rank()
683-
size = dist.get_world_size()
684-
print("Processes: " + str(size))
685-
686-
# device = torch.device('cpu')
687-
devid = rank_to_devid(rank, acc_per_rank)
688-
device = torch.device('cuda:{}'.format(devid))
689-
torch.cuda.set_device(device)
690-
curr_devid = torch.cuda.current_device()
691-
# print(f"curr_devid: {curr_devid}", flush=True)
692-
devcount = torch.cuda.device_count()
680+
if not download:
681+
mp.set_start_method('spawn', force=True)
682+
outputs = None
683+
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
684+
dist.init_process_group(backend='nccl')
685+
rank = dist.get_rank()
686+
size = dist.get_world_size()
687+
print("Processes: " + str(size))
688+
689+
# device = torch.device('cpu')
690+
devid = rank_to_devid(rank, acc_per_rank)
691+
device = torch.device('cuda:{}'.format(devid))
692+
torch.cuda.set_device(device)
693+
curr_devid = torch.cuda.current_device()
694+
# print(f"curr_devid: {curr_devid}", flush=True)
695+
devcount = torch.cuda.device_count()
693696

694697
if graphname == "Cora":
695698
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
@@ -755,6 +758,9 @@ def main():
755758
inputs.requires_grad = True
756759
data.y = data.y.to(device)
757760

761+
if download:
762+
exit()
763+
758764
if normalization:
759765
adj_matrix, _ = add_remaining_self_loops(edge_index, num_nodes=inputs.size(0))
760766
else:
@@ -778,6 +784,7 @@ def main():
778784
parser.add_argument("--normalization", type=str)
779785
parser.add_argument("--activations", type=str)
780786
parser.add_argument("--accuracy", type=str)
787+
parser.add_argument("--download", type=bool)
781788

782789
args = parser.parse_args()
783790
print(args)
@@ -792,10 +799,12 @@ def main():
792799
normalization = args.normalization == "True"
793800
activations = args.activations == "True"
794801
accuracy = args.accuracy == "True"
802+
download = args.download
795803

796-
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
797-
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer} {run_count}")
798-
exit()
804+
if not download:
805+
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
806+
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer} {run_count}")
807+
exit()
799808

800809
print(f"Arguments: epochs: {epochs} graph: {graphname} timing: {timing} mid: {mid_layer} norm: {normalization} act: {activations} acc: {accuracy} runs: {run_count}")
801810

‎gcn_distr_15d.py

+27-18
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import torch.distributed as dist
99

1010
from torch_geometric.data import Data, Dataset
11-
from torch_geometric.datasets import Planetoid, PPI, Reddit
11+
from torch_geometric.datasets import Planetoid, PPI
12+
from reddit import Reddit
1213
from torch_geometric.nn import GCNConv, ChebConv # noqa
1314
from torch_geometric.utils import add_remaining_self_loops, to_dense_adj, dense_to_sparse, to_scipy_sparse_matrix
1415
import torch_geometric.transforms as T
@@ -61,6 +62,7 @@
6162
run_count = 0
6263
run = 0
6364
replication = 0
65+
download = False
6466

6567
def start_time(group, rank, subset=False, src=None):
6668
global barrier_time
@@ -715,22 +717,23 @@ def main():
715717
print(socket.gethostname())
716718
seed = 0
717719

718-
mp.set_start_method('spawn', force=True)
719-
outputs = None
720-
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
720+
if not download:
721+
mp.set_start_method('spawn', force=True)
722+
outputs = None
723+
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
721724

722-
dist.init_process_group(backend='nccl')
723-
rank = dist.get_rank()
724-
size = dist.get_world_size()
725-
print("Processes: " + str(size))
725+
dist.init_process_group(backend='nccl')
726+
rank = dist.get_rank()
727+
size = dist.get_world_size()
728+
print("Processes: " + str(size))
726729

727-
# device = torch.device('cpu')
728-
devid = rank_to_devid(rank, acc_per_rank)
729-
device = torch.device('cuda:{}'.format(devid))
730-
torch.cuda.set_device(device)
731-
curr_devid = torch.cuda.current_device()
732-
# print(f"curr_devid: {curr_devid}", flush=True)
733-
devcount = torch.cuda.device_count()
730+
# device = torch.device('cpu')
731+
devid = rank_to_devid(rank, acc_per_rank)
732+
device = torch.device('cuda:{}'.format(devid))
733+
torch.cuda.set_device(device)
734+
curr_devid = torch.cuda.current_device()
735+
# print(f"curr_devid: {curr_devid}", flush=True)
736+
devcount = torch.cuda.device_count()
734737

735738
if graphname == "Cora":
736739
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
@@ -798,6 +801,9 @@ def main():
798801
inputs.requires_grad = True
799802
data.y = data.y.to(device)
800803

804+
if download:
805+
exit()
806+
801807
if normalization:
802808
adj_matrix, _ = add_remaining_self_loops(edge_index, num_nodes=inputs.size(0))
803809
else:
@@ -822,6 +828,7 @@ def main():
822828
parser.add_argument("--normalization", type=str)
823829
parser.add_argument("--activations", type=str)
824830
parser.add_argument("--accuracy", type=str)
831+
parser.add_argument("--download", type=bool)
825832

826833
args = parser.parse_args()
827834
print(args)
@@ -837,10 +844,12 @@ def main():
837844
activations = args.activations == "True"
838845
accuracy = args.accuracy == "True"
839846
replication = args.replication
847+
download = args.download
840848

841-
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
842-
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer} {run_count}")
843-
exit()
849+
if not download:
850+
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
851+
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer} {run_count}")
852+
exit()
844853

845854
print(f"Arguments: epochs: {epochs} graph: {graphname} timing: {timing} mid: {mid_layer} norm: {normalization} act: {activations} acc: {accuracy} runs: {run_count} rep: {replication}")
846855

‎gcn_distr_2d.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import torch.distributed as dist
1010

1111
from torch_geometric.data import Data, Dataset
12-
from torch_geometric.datasets import Planetoid, PPI, Reddit
12+
from torch_geometric.datasets import Planetoid, PPI
13+
from reddit import Reddit
1314
from torch_geometric.nn import GCNConv, ChebConv # noqa
1415
from torch_geometric.utils import (
1516
add_remaining_self_loops,
@@ -90,6 +91,7 @@
9091
no_occur_val = 42.1234
9192
run_count = 0
9293
run = 0
94+
download = False
9395

9496
def sync_and_sleep(rank, device):
9597
torch.cuda.synchronize(device=device)
@@ -1450,6 +1452,9 @@ def main(P, correctness_check, acc_per_rank):
14501452
data.y = torch.rand(n).uniform_(0, num_classes - 1)
14511453
data.train_mask = torch.ones(n).long()
14521454

1455+
if download:
1456+
exit()
1457+
14531458
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
14541459
dist.init_process_group(backend='nccl')
14551460
# dist.init_process_group('gloo', init_method='env://')
@@ -1523,6 +1528,7 @@ def main(P, correctness_check, acc_per_rank):
15231528
parser.add_argument("--normalization", type=str)
15241529
parser.add_argument("--activations", type=str)
15251530
parser.add_argument("--accuracy", type=str)
1531+
parser.add_argument("--download", type=bool)
15261532
args = parser.parse_args()
15271533
print(args)
15281534
P = args.processes
@@ -1547,10 +1553,12 @@ def main(P, correctness_check, acc_per_rank):
15471553
normalization = args.normalization == "True"
15481554
activations = args.activations == "True"
15491555
accuracy = args.accuracy == "True"
1556+
download = args.download
15501557

1551-
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
1552-
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer}")
1553-
exit()
1558+
if not download:
1559+
if (epochs is None) or (graphname is None) or (timing is None) or (mid_layer is None) or (run_count is None):
1560+
print(f"Error: missing argument {epochs} {graphname} {timing} {mid_layer}")
1561+
exit()
15541562

15551563
print(f"Arguments: epochs: {epochs} graph: {graphname} timing: {timing} mid: {mid_layer} norm: {normalization} act: {activations} acc: {accuracy}")
15561564

‎gcn_distr_3d.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import torch.distributed as dist
1010

1111
from torch_geometric.data import Data, Dataset
12-
from torch_geometric.datasets import Planetoid, PPI, Reddit
12+
from torch_geometric.datasets import Planetoid, PPI
13+
from reddit import Reddit
1314
from torch_geometric.nn import GCNConv, ChebConv # noqa
1415
from torch_geometric.utils import (
1516
add_remaining_self_loops,
@@ -1772,32 +1773,18 @@ def main(P, correctness_check, acc_per_rank):
17721773

17731774
if __name__ == '__main__':
17741775
parser = argparse.ArgumentParser()
1775-
parser.add_argument('--processes', metavar='P', type=int,
1776-
help='Number of processes')
1777-
parser.add_argument('--correctness', metavar='C', type=str,
1778-
help='Run correctness check')
17791776
parser.add_argument("--accperrank", type=int)
17801777
parser.add_argument("--epochs", type=int)
17811778
parser.add_argument("--graphname", type=str)
17821779
parser.add_argument("--timing", type=str)
17831780
parser.add_argument("--midlayer", type=int)
1784-
parser.add_argument("--local_rank", type=int)
17851781
args = parser.parse_args()
17861782
print(args)
1787-
P = args.processes
1788-
correctness_check = args.correctness
1789-
if P is None:
1790-
P = 1
17911783

17921784
acc_per_rank = args.accperrank
17931785
if acc_per_rank is None:
17941786
acc_per_rank = 1
17951787

1796-
if correctness_check is None or correctness_check == "nocheck":
1797-
correctness_check = False
1798-
else:
1799-
correctness_check = True
1800-
18011788
epochs = args.epochs
18021789
graphname = args.graphname
18031790
timing = args.timing == "True"

‎reddit.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
import os.path as osp
3+
4+
import torch
5+
import numpy as np
6+
import scipy.sparse as sp
7+
from torch_sparse import coalesce
8+
from torch_geometric.data import (InMemoryDataset, Data, download_url,
9+
extract_zip)
10+
11+
12+
class Reddit(InMemoryDataset):
13+
r"""The Reddit dataset from the `"Inductive Representation Learning on
14+
Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, containing
15+
Reddit posts belonging to different communities.
16+
17+
Args:
18+
root (string): Root directory where the dataset should be saved.
19+
transform (callable, optional): A function/transform that takes in an
20+
:obj:`torch_geometric.data.Data` object and returns a transformed
21+
version. The data object will be transformed before every access.
22+
(default: :obj:`None`)
23+
pre_transform (callable, optional): A function/transform that takes in
24+
an :obj:`torch_geometric.data.Data` object and returns a
25+
transformed version. The data object will be transformed before
26+
being saved to disk. (default: :obj:`None`)
27+
"""
28+
29+
url = 'https://data.dgl.ai/dataset/reddit.zip'
30+
31+
def __init__(self, root, transform=None, pre_transform=None):
32+
super(Reddit, self).__init__(root, transform, pre_transform)
33+
self.data, self.slices = torch.load(self.processed_paths[0])
34+
35+
@property
36+
def raw_file_names(self):
37+
return ['reddit_data.npz', 'reddit_graph.npz']
38+
39+
@property
40+
def processed_file_names(self):
41+
return 'data.pt'
42+
43+
def download(self):
44+
path = download_url(self.url, self.raw_dir)
45+
extract_zip(path, self.raw_dir)
46+
os.unlink(path)
47+
48+
def process(self):
49+
data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))
50+
x = torch.from_numpy(data['feature']).to(torch.float)
51+
y = torch.from_numpy(data['label']).to(torch.long)
52+
split = torch.from_numpy(data['node_types'])
53+
54+
adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz'))
55+
row = torch.from_numpy(adj.row).to(torch.long)
56+
col = torch.from_numpy(adj.col).to(torch.long)
57+
edge_index = torch.stack([row, col], dim=0)
58+
edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
59+
60+
data = Data(x=x, edge_index=edge_index, y=y)
61+
data.train_mask = split == 1
62+
data.val_mask = split == 2
63+
data.test_mask = split == 3
64+
65+
data = data if self.pre_transform is None else self.pre_transform(data)
66+
67+
torch.save(self.collate([data]), self.processed_paths[0])
68+
69+
def __repr__(self):
70+
return '{}()'.format(self.__class__.__name__)

0 commit comments

Comments
 (0)
Please sign in to comment.