Skip to content

Commit

Permalink
fix relgraphconv bug (dmlc#3256)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Aug 23, 2021
1 parent 8341244 commit b4cd60a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
16 changes: 8 additions & 8 deletions examples/pytorch/rgcn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,43 @@ pip install requests torch rdflib pandas
Example code was tested with rdflib 4.2.2 and pandas 0.23.4

### Entity Classification
AIFB: accuracy 92.59% (3 runs, DGL), 95.83% (paper)
AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper)
```
python3 entity_classify.py -d aifb --testing --gpu 0
```

MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper)
MUTAG: accuracy 70.59% (3 runs, DGL), 73.23% (paper)
```
python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0
```

BGS: accuracy 89.66% (3 runs, DGL), 83.10% (paper)
BGS: accuracy 93.10% (3 runs, DGL), 83.10% (paper)
```
python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```

AM: accuracy 89.73% (3 runs, DGL), 89.29% (paper)
AM: accuracy 89.22% (3 runs, DGL), 89.29% (paper)
```
python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --testing
```

### Entity Classification with minibatch
AIFB: accuracy avg(5 runs) 90.56%, best 94.44% (DGL)
AIFB: accuracy avg(5 runs) 90.00%, best 94.44% (DGL)
```
python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout='20,20' --batch-size 128
```

MUTAG: accuracy avg(10 runs) 69.41%, best 76.47% (DGL)
MUTAG: accuracy avg(10 runs) 62.94%, best 72.06% (DGL)
```
python3 entity_classify_mp.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --dgl-sparse --n-epochs 20 --sparse-lr 0.01 --dropout 0.5
```

BGS: accuracy avg(5 runs) 85.52%, best 93.10% (DGL)
BGS: accuracy avg(5 runs) 78.62%, best 86.21% (DGL)
```
python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --dgl-sparse --lr 0.01 --sparse-lr 0.05 --dropout 0.3
```

AM: accuracy avg(5 runs) 88.59%, best 88.89% (DGL)
AM: accuracy avg(5 runs) 87.37%, best 89.9% (DGL)
```
python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dgl-sparse --lr 0.01 --sparse-lr 0.02 --dropout 0.7
```
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/nn/pytorch/conv/relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,9 @@ def basis_message_func(self, edges, etypes):
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
idim = weight.shape[1]
weight = weight.view(-1, weight.shape[2])
flatidx = etypes * weight.shape[1] + h
flatidx = etypes * idim + h
msg = weight.index_select(0, flatidx)
elif self.low_mem:
# A more memory-friendly implementation.
Expand Down

0 comments on commit b4cd60a

Please sign in to comment.