Skip to content

Commit

Permalink
SBM (PATTERN and CLUSTER) accuracy fix (rampasek#12)
Browse files Browse the repository at this point in the history
* Accuracy eval fn for PATTERN and CLUSTER datasets

* updated PATTERN and CLUSTER results
  • Loading branch information
rampasek authored Sep 9, 2022
1 parent cbaed0b commit 6305368
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 27 deletions.
2 changes: 1 addition & 1 deletion configs/GPS/cluster-GPS-ESLapPE.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: CLUSTER
Expand Down
4 changes: 2 additions & 2 deletions configs/GPS/cluster-GPS.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: CLUSTER
Expand Down Expand Up @@ -59,7 +59,7 @@ optim:
clip_grad_norm: True
optimizer: adamW
weight_decay: 1e-5
base_lr: 0.001
base_lr: 0.0005
max_epoch: 100
scheduler: cosine_with_warmup
num_warmup_epochs: 5
2 changes: 1 addition & 1 deletion configs/GPS/pattern-GPS-ESLapPE.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: PATTERN
Expand Down
16 changes: 3 additions & 13 deletions configs/GPS/pattern-GPS.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: PATTERN
Expand All @@ -21,7 +21,7 @@ posenc_LapPE:
eigen:
laplacian_norm: none
eigvec_norm: L2
max_freqs: 10
max_freqs: 16
model: DeepSet
dim_pe: 16
layers: 2
Expand Down Expand Up @@ -55,21 +55,11 @@ gnn:
dropout: 0.0
agg: mean
normalize_adj: False
#optim:
# clip_grad_norm: True
# optimizer: adamW
# weight_decay: 0.0
# base_lr: 0.0005
# max_epoch: 1000
# scheduler: reduce_on_plateau
# reduce_factor: 0.5
# schedule_patience: 10
# min_lr: 1e-5
optim:
clip_grad_norm: True
optimizer: adamW
weight_decay: 1e-5
base_lr: 0.001
base_lr: 0.0005
max_epoch: 100
scheduler: cosine_with_warmup
num_warmup_epochs: 5
2 changes: 1 addition & 1 deletion configs/SAN/cluster-SAN.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: gtblueprint
Expand Down
2 changes: 1 addition & 1 deletion configs/SAN/pattern-SAN.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: results
metric_best: accuracy
metric_best: accuracy-SBM
wandb:
use: True
project: gtblueprint
Expand Down
Binary file modified final-results.zip
Binary file not shown.
34 changes: 32 additions & 2 deletions graphgps/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
from scipy.stats import stats
from sklearn.metrics import accuracy_score, precision_score, recall_score, \
f1_score, roc_auc_score, mean_absolute_error, mean_squared_error
f1_score, roc_auc_score, mean_absolute_error, mean_squared_error, \
confusion_matrix
from sklearn.metrics import r2_score
from torch_geometric.graphgym import get_current_gpu_usage
from torch_geometric.graphgym.config import cfg
Expand All @@ -17,6 +18,29 @@
from graphgps.metric_wrapper import MetricWrapper


def accuracy_SBM(targets, pred_int):
    """Class-balanced accuracy for Benchmarking GNN's PATTERN and CLUSTER.

    Mirrors the evaluation protocol of
    https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/train/metrics.py#L34
    i.e. the mean over classes of the per-class recall, where a class that
    has no samples in ``targets`` contributes 0 to the mean.

    Classes are identified by their integer label value: every id in
    ``0 .. max(label)`` counts as a class, so row/label alignment does not
    depend on which labels happen to be present in the inputs.

    Args:
        targets: Ground-truth class ids (torch tensor or array-like).
        pred_int: Predicted class ids, same number of elements as
            ``targets`` (torch tensor or array-like).

    Returns:
        Balanced accuracy in ``[0, 1]``; ``0.0`` when ``targets`` is empty.

    Raises:
        ValueError: If ``targets`` and ``pred_int`` differ in length.
    """

    def _to_flat_numpy(x):
        # Torch tensors must be detached and moved to host memory before
        # NumPy can view them; anything else goes through np.asarray.
        if hasattr(x, 'detach'):
            x = x.detach().cpu().numpy()
        return np.asarray(x).reshape(-1)

    true = _to_flat_numpy(targets)
    pred = _to_flat_numpy(pred_int)
    if true.shape != pred.shape:
        raise ValueError(
            f"Found input variables with inconsistent numbers of samples: "
            f"[{true.shape[0]}, {pred.shape[0]}]")
    if true.size == 0:
        # No samples: avoid a 0/0 division and report the lowest score.
        return 0.0

    nb_classes = int(max(true.max(), pred.max())) + 1
    pr_classes = np.zeros(nb_classes)
    for r in range(nb_classes):
        in_class = true == r
        cluster_size = int(in_class.sum())
        if cluster_size != 0:
            # Recall of class r: correctly predicted / all samples of class r.
            pr_classes[r] = float((pred[in_class] == r).sum()) / cluster_size
        # Classes without any ground-truth sample keep a score of 0.
    acc = np.sum(pr_classes) / float(nb_classes)
    return acc


class CustomLogger(Logger):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -59,23 +83,29 @@ def classification_binary(self):
auroc_score = 0.

reformat = lambda x: round(float(x), cfg.round)
return {
res = {
'accuracy': reformat(accuracy_score(true, pred_int)),
'precision': reformat(precision_score(true, pred_int)),
'recall': reformat(recall_score(true, pred_int)),
'f1': reformat(f1_score(true, pred_int)),
'auc': reformat(auroc_score),
}
if cfg.metric_best == 'accuracy-SBM':
res['accuracy-SBM'] = reformat(accuracy_SBM(true, pred_int))
return res

def classification_multi(self):
true, pred_score = torch.cat(self._true), torch.cat(self._pred)
pred_int = self._get_pred_int(pred_score)
reformat = lambda x: round(float(x), cfg.round)

res = {
'accuracy': reformat(accuracy_score(true, pred_int)),
'f1': reformat(f1_score(true, pred_int,
average='macro', zero_division=0)),
}
if cfg.metric_best == 'accuracy-SBM':
res['accuracy-SBM'] = reformat(accuracy_SBM(true, pred_int))
if true.shape[0] < 1e7:
# AUROC computation for very large datasets runs out of memory.
# TorchMetrics AUROC on GPU is much faster than sklearn for large ds
Expand Down
12 changes: 10 additions & 2 deletions run/run_experiments.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ run_repeats ${DATASET} GPS "name_tag GPSwLapPE.GatedGCN+Trf.10run"


DATASET="pattern"
run_repeats ${DATASET} GPS "name_tag GPSwLapPE.GatedGCN+Trf.10run"
run_repeats ${DATASET} GPS "name_tag GPSwLapPE.GatedGCN+Trf.eigv16.lr0005"


DATASET="cluster"
run_repeats ${DATASET} GPS "name_tag GPSwLapPE.GatedGCN+Trf.10run.drp01.wd-5"
run_repeats ${DATASET} GPS "name_tag GPSwLapPE.GatedGCN+Trf.lr0005.10run"


DATASET="ogbg-molhiv"
Expand All @@ -92,3 +92,11 @@ run_repeats ${DATASET} GPSmedium+RWSE "name_tag GPSwRWSE.medium.lyr10.dim384.hea
DATASET="malnettiny"
run_repeats ${DATASET} GPS-noPE "name_tag GPS-noPE.GatedGCN+Perf.lyr5.dim64.10runs"
run_repeats ${DATASET} GPS-noPE "name_tag GPS-noPE.GatedGCN+Trf.lyr5.dim64.bs4.bacc4.10run train.batch_size 4 optim.batch_accumulation 4 gt.layer_type CustomGatedGCN+Transformer"


################################################################################
##### extra
################################################################################
cfg_dir="configs/SAN"
DATASET="pattern"
#run_repeats ${DATASET} SAN "name_tag SAN.LapPE.10run"
2 changes: 1 addition & 1 deletion tests/configs/graph/cluster-SAN.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: tests/results
metric_best: accuracy
metric_best: accuracy-SBM
dataset:
format: PyG-GNNBenchmarkDataset
name: CLUSTER
Expand Down
2 changes: 1 addition & 1 deletion tests/configs/graph/cluster.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: tests/results
metric_best: accuracy
metric_best: accuracy-SBM
dataset:
format: PyG-GNNBenchmarkDataset
name: CLUSTER
Expand Down
2 changes: 1 addition & 1 deletion tests/configs/graph/pattern-SAN.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
out_dir: tests/results
tensorboard_each_run: True # Log to Tensorboard each run
metric_best: accuracy
metric_best: accuracy-SBM
dataset:
format: PyG-GNNBenchmarkDataset
name: PATTERN
Expand Down
2 changes: 1 addition & 1 deletion tests/configs/graph/pattern.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
out_dir: tests/results
metric_best: accuracy
metric_best: accuracy-SBM
dataset:
format: PyG-GNNBenchmarkDataset
name: PATTERN
Expand Down

0 comments on commit 6305368

Please sign in to comment.