ICML; privacy module is coming soon.
Minji Yoon committed Jun 1, 2023
1 parent 8f11199 commit 7ee9198
Showing 9 changed files with 57 additions and 23 deletions.
args.py (2 changes: 1 addition & 1 deletion)
@@ -123,7 +123,7 @@ def get_args():
                         help='GPT, XLNet, or Bayes')
     parser.add_argument('--gpt_softmax_temperature', type=float, default=1.,
                         help='Temperature used to sample')
-    parser.add_argument('--gpt_epochs', type=int, default=5,
+    parser.add_argument('--gpt_epochs', type=int, default=50,
                         help='Number of epochs to train.')
     parser.add_argument('--gpt_batch_size', type=int, default=128,
                         help='Size of batch.')
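Note: the only functional change in args.py is the generator's default training budget, raised from 5 to 50 epochs. A minimal argparse sketch (standalone, flag name copied from the diff) showing that the old budget stays reachable from the command line:

    import argparse

    # Same flag as in get_args(); the new default applies only when
    # --gpt_epochs is omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpt_epochs', type=int, default=50,
                        help='Number of epochs to train.')

    print(parser.parse_args([]).gpt_epochs)                     # 50
    print(parser.parse_args(['--gpt_epochs', '5']).gpt_epochs)  # 5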
generator/cluster.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@ def kmeans(feats, cluster_num, cluster_size, cluster_sample_num):
     else:
         x = feats
 
-    pca = PCA(n_components=128)
+    pca = PCA(n_components=min(feats.shape[1], 128))
     x_pca = pca.fit_transform(x)
 
     clf = KMeansConstrained(n_clusters=cluster_num, size_min=cluster_size, init='random', n_init=1, max_iter=8)
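Note: scikit-learn's PCA raises a ValueError when n_components exceeds min(n_samples, n_features), so the guard above keeps the constrained k-means pipeline working on datasets whose feature dimension is below 128. It caps only by the feature dimension, so an input with fewer than 128 rows could still trip the same check. A minimal sketch of the failure the fix avoids (shapes invented for illustration):

    import numpy as np
    from sklearn.decomposition import PCA

    feats = np.random.rand(500, 64)   # hypothetical 64-dim features

    # PCA(n_components=128).fit_transform(feats) would raise ValueError here;
    # capping by the feature dimension keeps the call valid.
    pca = PCA(n_components=min(feats.shape[1], 128))
    x_pca = pca.fit_transform(feats)
    print(x_pca.shape)                # (500, 64)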
generator/gpt/dataset.py (8 changes: 6 additions & 2 deletions)
@@ -7,6 +7,7 @@
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, args, adjs, cluster_ids, labels, ids):
         self.adjs = adjs
+        self.adjs_list = isinstance(adjs, list)
         self.cluster_ids = cluster_ids
         self.labels = labels
         self.ids = ids
@@ -39,7 +40,10 @@ def __getitem__(self, index):
         if target_id == org_empty_id:
             source_ids = []
         else:
-            source_ids = np.nonzero(self.adjs[target_id])[0].tolist()
+            if self.adjs_list:
+                source_ids = self.adjs[target_id]
+            else:
+                source_ids = np.nonzero(self.adjs[target_id])[0].tolist()
         # Sample fixed number of neighbors
         if len(source_ids) == 0:
             sampled_ids = self.sample_num * [org_empty_id]
@@ -49,7 +53,7 @@ def __getitem__(self, index):
             sampled_ids = np.random.choice(source_ids, self.sample_num, replace = False).tolist()
 
         if self.noise_num > 0:
-            perm = np.random.permutation(self.adjs.shape[0])[:self.noise_num]
+            perm = np.random.permutation(len(self.adjs))[:self.noise_num]
             sampled_ids = np.concatenate((sampled_ids, perm), axis=0)
 
         sampled_cluster_ids.extend(self.cluster_ids[sampled_ids])
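Note: the recurring pattern in this commit is that self.adjs may now be either a dense adjacency matrix (one row per node) or a plain Python list of neighbor lists; isinstance(adjs, list) selects the matching access path, and len(self.adjs) replaces self.adjs.shape[0] because len() returns the node count for both representations. A standalone sketch on a toy graph (names and sizes invented):

    import numpy as np

    # Toy path graph 0-1-2 in both representations.
    adj_dense = np.array([[0, 1, 0],
                          [1, 0, 1],
                          [0, 1, 0]])
    adj_list = [[1], [0, 2], [1]]

    target_id = 1
    for adjs in (adj_dense, adj_list):
        if isinstance(adjs, list):
            source_ids = adjs[target_id]                          # already ids
        else:
            source_ids = np.nonzero(adjs[target_id])[0].tolist()  # dense row to ids
        print(source_ids, len(adjs))                              # [0, 2] 3 both times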
generator/gpt/gpt.py (2 changes: 1 addition & 1 deletion)
@@ -67,7 +67,7 @@ def generate(args, model, labels, ids, split):
     result = torch.cat(result, dim = 0)
     print("[GPT] name: {}, split: {}, generation time: {:.3f}".format(args.gpt_train_name, split, perf_counter() - start_time))
 
-    return result
+    return result.cpu()
 
 
 def run(args, graphs, feats, labels, ids):
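Note: generate() now returns result.cpu(), which moves the generated batch to host memory so callers can index it with numpy and the GPU copy can be freed. A small sketch of the failure this avoids (guarded so it also runs on CPU-only machines):

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    result = torch.randint(0, 10, (4, 3), device=device)  # stand-in for generated ids

    # On a CUDA tensor, .numpy() raises TypeError; moving to host first is safe,
    # and .cpu() is a no-op when the tensor is already on the CPU.
    host_result = result.cpu()
    print(host_result.numpy().shape)   # (4, 3)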
generator/gpt/model.py (7 changes: 1 addition & 6 deletions)
@@ -163,10 +163,6 @@ def __init__(self, config):
         # remove start_id
         self.head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
 
-        #mask = torch.ones(1, config.block_size, config.vocab_size - 1)
-        #mask[:, 0, -1] = 0
-        #self.register_buffer("output_mask", mask)
-
         self.apply(self._init_weights)
         self.criterion = nn.CrossEntropyLoss()

@@ -236,11 +232,10 @@ def forward(self, idx, classes, targets=None):
 
         x = self.drop(token_embeddings + position_embeddings[:, :t])
         q = self.drop((self.query_emb + position_embeddings[:, 1:(t+1)] + class_embeddings).expand_as(x))
-        _, q = self.blocks((x,q))
+        _, q = self.blocks((x, q))
         q = self.ln_f(q)
         logits = self.head(q)
         logits[:, 0, -1] = float('-inf')
-        #logits = logits.masked_fill(self.output_mask[:, :t , : ] == 0, float('-inf'))
 
         loss = None
         if targets is not None:
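Note: the deleted lines were a commented-out buffer-based variant of the masking that survives as the in-place assignment logits[:, 0, -1] = float('-inf'). Setting a logit to -inf drives its softmax probability to zero, so the last vocabulary entry can never be produced at the first position. A minimal sketch (shapes invented):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 4, 5)      # (batch, block positions, vocab_size - 1)
    logits[:, 0, -1] = float('-inf')   # forbid the last token at position 0

    probs = F.softmax(logits, dim=-1)
    print(probs[:, 0, -1])             # tensor([0., 0.]): never sampled there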
run.sh (4 changes: 2 additions & 2 deletions)
@@ -1,8 +1,8 @@
-DATASETS=("cora")
+DATASETS=("cora" "citeseer")
 data_length=${#DATASETS[@]}
 
 # Experiment 1: effects of noise to aggregation strategies
-NOISES=(0)
+NOISES=(0 2 4)
 noise_length=${#NOISES[@]}
 
 for ((i=0;i<$data_length;i++))
task/utils/dataset.py (9 changes: 7 additions & 2 deletions)
@@ -5,10 +5,12 @@
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, args, split, adjs, feats, labels, ids):
         self.adjs = adjs
+        self.adjs_list = isinstance(adjs, list)
         self.feats = feats
         self.labels = labels
         self.ids = ids[split]
 
+        self.node_num = feats.shape[0]
         self.empty_id = feats.shape[0]
         self.feats = np.concatenate((self.feats, np.zeros((1, feats.shape[1]))), axis=0)
 
@@ -32,7 +34,10 @@ def __getitem__(self, index):
         if target_id == self.empty_id:
             source_ids = []
         else:
-            source_ids = np.nonzero(self.adjs[target_id])[0].tolist()
+            if self.adjs_list:
+                source_ids = self.adjs[target_id]
+            else:
+                source_ids = np.nonzero(self.adjs[target_id])[0].tolist()
         # Sample fixed number of neighbors
         if len(source_ids) == 0:
             sampled_ids = self.sample_num * [self.empty_id]
@@ -42,7 +47,7 @@ def __getitem__(self, index):
             sampled_ids = np.random.choice(source_ids, self.sample_num, replace = False).tolist()
 
         if self.noise_num > 0:
-            perm = np.random.permutation(self.adjs.shape[0])[:self.noise_num]
+            perm = np.random.permutation(self.node_num)[:self.noise_num]
             sampled_ids = np.concatenate((sampled_ids, perm), axis=0)
 
         sampled_nodes.extend(sampled_ids)
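Note: this mirrors the generator-side change, with one extra subtlety: a list of neighbor lists has no .shape attribute, so the noise branch now samples random node ids from self.node_num (cached from feats.shape[0] before the empty-feature row is appended) instead of self.adjs.shape[0]. A toy sketch of the noise injection (sizes and ids invented):

    import numpy as np

    node_num = 10           # cached from feats.shape[0] at __init__ time
    noise_num = 2
    sampled_ids = [3, 7]    # hypothetical sampled neighbor ids

    # Works whether the adjacency is a dense matrix or a list of lists,
    # since only the node count is needed.
    perm = np.random.permutation(node_num)[:noise_num]
    sampled_ids = np.concatenate((sampled_ids, perm), axis=0)
    print(sampled_ids)      # two real neighbors followed by two noise ids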
task/utils/utils.py (35 changes: 32 additions & 3 deletions)
@@ -1,8 +1,10 @@
 import csv
+import os.path as osp
+import networkx as nx
 import numpy as np
 import random
 import scipy.sparse as sp
 import pandas as pd
 from collections import defaultdict
 from sklearn import metrics
 from sklearn.preprocessing import normalize
@@ -11,6 +13,9 @@
 import torch
 import torch.nn.functional as F
 
+from ogb.io.read_graph_pyg import read_graph_pyg
+from torch_geometric.utils import to_undirected
+
 
 def set_seed(seed):
     random.seed(seed)
@@ -20,8 +25,8 @@ def set_seed(seed):
 
 train_ratio = 0.4
 val_ratio = 0.2
-def split_ids(args, graph, labels):
-    node_ids = list(range(graph.shape[0]))
+def split_ids(args, node_num):
+    node_ids = list(range(node_num))
     random.shuffle(node_ids)
 
     ids = {}
@@ -31,6 +36,30 @@
 
     return ids
 
+
+def convert_to_edge_list(edge_index, X):
+    edge_list = []
+    sorted, indices = torch.sort(edge_index[1])
+    source_ids = edge_index[0][indices]
+    target_ids = edge_index[1][indices]
+
+    j = 0
+    for i in range(X.shape[0]):
+        neighbor_list = []
+        while j < target_ids.shape[0] and target_ids[j] == i:
+            neighbor_list.append(source_ids[j].item())
+            j += 1
+        edge_list.append(neighbor_list)
+
+    return edge_list
+
+
+def normalize_features(features):
+    features = features - features.min()
+    features.div_(features.sum(dim=-1, keepdim=True).clamp_(min=1.))
+    return features
+
+
 def load_ogbn(args):
     master_file = args.data_dir + "/ogbn-master.csv"
     master = pd.read_csv(master_file, index_col = 0)
@@ -43,7 +72,7 @@ def load_ogbn(args):
 
     data_dir = args.data_dir + "/" + args.dataset + "/"
     data = read_graph_pyg(data_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=binary)[0]
-    data.x = normalize_features(data.x)
+    #data.x = normalize_features(data.x)
     node_feat = data.x.numpy()
 
     data.edge_index = to_undirected(data.edge_index)
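Note: convert_to_edge_list turns a PyG-style edge_index (a 2 x E tensor of source/target pairs) into the list-of-neighbor-lists form the Dataset classes above now accept: it sorts edges by target node, then makes a single pass over the sorted targets, collecting each node's incoming sources (nodes without edges get an empty list). Two smaller observations: the local name sorted shadows the Python built-in (harmless here), and normalize_features is added while its only visible call in load_ogbn is simultaneously commented out, so OGB features now pass through unnormalized. A standalone sketch on a toy graph, with the function body copied from the diff:

    import torch

    def convert_to_edge_list(edge_index, X):
        # Copied from the diff above; X is only used for its node count.
        edge_list = []
        sorted, indices = torch.sort(edge_index[1])
        source_ids = edge_index[0][indices]
        target_ids = edge_index[1][indices]

        j = 0
        for i in range(X.shape[0]):
            neighbor_list = []
            while j < target_ids.shape[0] and target_ids[j] == i:
                neighbor_list.append(source_ids[j].item())
                j += 1
            edge_list.append(neighbor_list)

        return edge_list

    # Path graph 0-1-2, stored with both edge directions.
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    X = torch.zeros(3, 4)

    print(convert_to_edge_list(edge_index, X))   # [[1], [0, 2], [1]]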
test.py (11 changes: 6 additions & 5 deletions)
@@ -46,7 +46,7 @@ def main():
 
     # Load the original graph datasets
     adj, feat, label, feat_size, label_size = load_graph(args)
-    ids = split_ids(args, adj, label)
+    ids = split_ids(args, feat.shape[0])
    args.feat_size = feat_size
     args.label_size = label_size
     args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -76,18 +76,19 @@ def main():
         ## Check GNN performance on the generated dataset
         start_time = perf_counter()
         acc_mic, acc_mac = evaluate(args, gen_train_set, gen_val_set, gen_test_set)
-        acc_mic_list[0, : , t] = acc_mic
-        acc_mac_list[0, : , t] = acc_mac
+        acc_mic_list[1, : , t] = acc_mic
+        acc_mac_list[1, : , t] = acc_mac
         print('Synthetic evaluation time: {:.3f}, acc: {}'.format(perf_counter() - start_time, acc_mic))
 
     test_acc_avg = np.average(acc_mic_list, axis=2)
     test_acc_std = np.std(acc_mic_list, axis=2)
 
-    print('\nTask: ' + args.task_name + ', Dataset: ' + args.dataset)
+    print('Task: ' + args.task_name + ', Dataset: ' + args.dataset)
     for model_name in args.model_list:
         print(model_name, end=', ')
     print()
     for model_id in range(len(args.model_list)):
-        print("\nORI: {:.2f} {:.3f}, GEN: {:.2f} {:.3f}".format(test_acc_avg[0][model_id], test_acc_std[0][model_id],\
+        print("ORG: {:.2f} {:.3f}, GEN: {:.2f} {:.3f}".format(test_acc_avg[0][model_id], test_acc_std[0][model_id],\
               test_acc_avg[1][model_id], test_acc_std[1][model_id]))


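Note: two fixes land in test.py. First, split_ids is called with the node count (feat.shape[0]) to match its new signature in task/utils/utils.py, which no longer assumes the adjacency object has a .shape. Second, results on the generated graphs are now written to row 1 of the accuracy arrays instead of overwriting the original-graph results in row 0, which is what the final ORG/GEN summary reads. A toy sketch of the indexing fix (sizes and values invented):

    import numpy as np

    num_models, num_trials = 2, 3
    # axis 0: index 0 = original graph (ORG), index 1 = generated graph (GEN)
    acc_mic_list = np.zeros((2, num_models, num_trials))

    t = 0
    acc_mic = np.array([0.81, 0.79])   # hypothetical per-model accuracies
    acc_mic_list[1, :, t] = acc_mic    # GEN results go to row 1; row 0 stays ORG

    test_acc_avg = np.average(acc_mic_list, axis=2)
    print(test_acc_avg[0], test_acc_avg[1])   # ORG row, GEN row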
