Skip to content

Commit

Permalink
Update (dmlc#1319)
Browse files Browse the repository at this point in the history
  • Loading branch information
mufeili authored Mar 5, 2020
1 parent 066d290 commit 4ec8f20
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch/gin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Graph Isomorphism Network (GIN)

Dependencies
------------
- PyTorch 1.0.1+
- PyTorch 1.1.0+
- sklearn
- tqdm

Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/gin/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def collate(samples):
for g in graphs:
# deal with node feats
for key in g.node_attr_schemes().keys():
g.ndata[key] = torch.from_numpy(g.ndata[key]).float()
g.ndata[key] = g.ndata[key].float()
# no edge feats
batched_graph = dgl.batch(graphs)
labels = torch.tensor(labels)
Expand Down
8 changes: 4 additions & 4 deletions examples/pytorch/gin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def eval_net(args, net, dataloader, criterion):
def main(args):

# set up seeds, args.seed supported
torch.manual_seed(seed=0)
np.random.seed(seed=0)
torch.manual_seed(seed=args.seed)
np.random.seed(seed=args.seed)

is_cuda = not args.disable_cuda and torch.cuda.is_available()

if is_cuda:
args.device = torch.device("cuda:" + str(args.device))
torch.cuda.manual_seed_all(seed=0)
torch.cuda.manual_seed_all(seed=args.seed)
else:
args.device = torch.device("cpu")

Expand Down Expand Up @@ -109,9 +109,9 @@ def main(args):
lrbar = tqdm(range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)

for epoch, _, _ in zip(tbar, vbar, lrbar):
scheduler.step()

train(args, model, trainloader, optimizer, criterion, epoch)
scheduler.step()

train_loss, train_acc = eval_net(
args, model, trainloader, criterion)
Expand Down
7 changes: 1 addition & 6 deletions examples/pytorch/gin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _parse(self):
# dataset
self.parser.add_argument(
'--dataset', type=str, default="MUTAG",
choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI'],
help='name of dataset (default: MUTAG)')
self.parser.add_argument(
'--batch_size', type=int, default=32,
Expand All @@ -39,9 +40,6 @@ def _parse(self):
help='which gpu device to use (default: 0)')

# net
self.parser.add_argument(
'--net', type=str, default="gin",
help='gnn net (default: gin)')
self.parser.add_argument(
'--num_layers', type=int, default=5,
help='number of layers (default: 5)')
Expand All @@ -64,9 +62,6 @@ def _parse(self):
self.parser.add_argument(
'--learn_eps', action="store_true",
help='learn the epsilon weighting')
self.parser.add_argument(
'--degree_as_tag', action="store_true",
help='take the degree of nodes as input feature')

# learning
self.parser.add_argument(
Expand Down

0 comments on commit 4ec8f20

Please sign in to comment.