Skip to content

Commit

Permalink
fix bugs (dmlc#3008)
Browse files Browse the repository at this point in the history
Co-authored-by: zhjwy9343 <[email protected]>
  • Loading branch information
hengruizhang98 and zhjwy9343 authored Jun 16, 2021
1 parent 36c6c64 commit 18eaad1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/pytorch/mvgrl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang9
--wd2 float Weight decay of linear classifier. Default is 0.0.
--epsilon float Edge mask threshold. Default is 0.01.
--hid_dim int Embedding dimension. Default is 512.
--sample_size int Subgraph size. Default is 2000.
```

## How to run examples
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/mvgrl/node/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,4 @@
accs.append(acc * 100)

accs = th.stack(accs)
print(accs.mean().item(), accs.std().item())
print(accs.mean().item(), accs.std().item())
5 changes: 3 additions & 2 deletions examples/pytorch/mvgrl/node/main_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
parser.add_argument('--wd2', type=float, default=0., help='Weight decay of linear evaluator.')
parser.add_argument('--epsilon', type=float, default=0.01, help='Edge mask threshold of diffusion graph.')
parser.add_argument("--hid_dim", type=int, default=512, help='Hidden layer dim.')
parser.add_argument("--sample_size", type=int, default=2000, help='Subgraph size.')

args = parser.parse_args()

Expand Down Expand Up @@ -54,7 +55,7 @@

n_node = graph.number_of_nodes()

sample_size = 2000
sample_size = args.sample_size

lbl1 = th.ones(sample_size * 2)
lbl2 = th.zeros(sample_size * 2)
Expand Down Expand Up @@ -153,4 +154,4 @@
accs.append(acc * 100)

accs = th.stack(accs)
print(accs.mean().item(), accs.std().item())
print(accs.mean().item(), accs.std().item())

0 comments on commit 18eaad1

Please sign in to comment.