
Commit

benchmark task
RexYing committed Apr 18, 2018
1 parent 1523141 commit 361ade7
Showing 2 changed files with 43 additions and 66 deletions.
55 changes: 12 additions & 43 deletions load_data.py
@@ -3,17 +3,19 @@
import networkx as nx
from graph_functions import *


def read_graphfile(datadir, dataname):
''' Read data from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
'''
filename_graph_indic=datadir+dataname+"/"+dataname+"_graph_indicator.txt"
# index of graphs that a given node belongs to
graph_indic={}

with open(filename_graph_indic) as f:
i=1
for line in f:
line=line.strip("\n")
graph_indic[i]=int(line)
i+=1
i=1
for line in f:
line=line.strip("\n")
graph_indic[i]=int(line)
i+=1
filename_nodes=datadir+dataname+"/"+dataname+"_node_labels.txt"
node_labels=[]
with open(filename_nodes) as f:
@@ -28,10 +30,11 @@ def read_graphfile(datadir, dataname):
line=line.strip("\n")
graph_labels.append(int(line))

# graph index starts with 1 in file
adj_list={i:[] for i in range(1,len(graph_labels)+1)}
filename_adj="../data/"+dataname+"/"+dataname+"_A.txt"
gen_list=[]
index_graph={k:[] for k in range(1,1+len(graph_labels))}
index_graph={i:[] for i in range(1,len(graph_labels)+1)}
with open(filename_adj) as f:
for line in f:
line=line.strip("\n").split(",")
@@ -50,7 +53,8 @@

# add features and labels
G.graph['label'] = graph_labels[i]

for u in G.nodes():
G.node[u]['label'] = node_labels[u]

graphs[i] = G

@@ -66,41 +70,6 @@ def read_graphfile(datadir, dataname):
mapping[n]=it
it+=1


graphs[i] = nx.relabel_nodes(graphs[i], mapping)
return graphs,adj_list, np.array(node_labels),np.array(graph_labels),graph_indic,index_graph

def load_data(datadir, dataname):

graphs,adj_list, node_labels,graph_labels,graph_indic,index_graph = read_graphfile(datadir, dataname)
graph_indic_array=np.sort([v for k,v in graph_indic.iteritems()])

ind=range(1,1+len(graphs))
np.random.shuffle(ind)
training_set=ind[:1000]
test_set=ind[1000:]

n_classes=np.max(node_labels)
train_features=[None]*len(training_set)
it=0
for i in training_set:
cand_nodes=index_graph[i]
train_features[it]=np.zeros((len(cand_nodes),n_classes+1))
train_features[it][np.arange(len(cand_nodes)), node_labels[cand_nodes]]=1
it+=1


test_features=[None]*len(test_set)
it=0
for i in test_set:
cand_nodes=np.where(graph_indic_array==i)[0]
#test_features[it]=np.eye(n_classes+1)[ node_labels[cand_nodes]]
test_features[it]=np.zeros((len(cand_nodes),n_classes+1))
test_features[it][np.arange(len(cand_nodes)), node_labels[cand_nodes]]=1
it+=1


train_graphs,train_targets=[graphs[i] for i in training_set], [graph_labels[i-1] for i in training_set]
test_graphs,test_targets=[graphs[i] for i in test_set], [graph_labels[i-1] for i in test_set]
return train_graphs,np.array(train_targets),train_features, test_graphs,np.array(test_targets),test_features
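
To put the refactor in context, here is a minimal usage sketch for the updated read_graphfile; it is not part of this commit. The data/ directory and the ENZYMES dataset name are hypothetical placeholders, and the six-tuple unpacking mirrors the return statement above.

import numpy as np

import load_data

# Hypothetical sketch: read the ENZYMES benchmark with the new datadir
# argument. read_graphfile builds paths as datadir + dataname + "/", e.g.
# data/ENZYMES/ENZYMES_graph_indicator.txt; note that in this revision the
# _A.txt adjacency file is still read from the hardcoded "../data/" prefix.
graphs, adj_list, node_labels, graph_labels, graph_indic, index_graph = \
    load_data.read_graphfile('data/', 'ENZYMES')
print('graphs:', len(graphs), '; classes:', len(np.unique(graph_labels)))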

54 changes: 31 additions & 23 deletions train.py
@@ -13,6 +13,7 @@
import gen.feat as featgen
import gen.data as datagen
from graph_sampler import GraphSampler
import load_data
import util

def synthetic_task_test(dataset, model, args):
@@ -64,52 +65,59 @@ def synthetic_task_train(dataset, model, args, same_feat=True):

return model

def synthetic_task1(args, export_graphs=False):

# data
graphs1 = datagen.gen_ba(range(40, 60), range(4, 5), 500,
featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float) *0.5))
for G in graphs1:
G.graph['label'] = 0
if export_graphs:
util.draw_graph_list(graphs1[:16], 4, 4, 'figs/ba')

graphs2 = datagen.gen_2community_ba(range(20, 30), range(4, 5), 500, 0.3,
[featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)*0.5)])
for G in graphs2:
G.graph['label'] = 1
if export_graphs:
util.draw_graph_list(graphs2[:16], 4, 4, 'figs/ba2')

graphs = graphs1 + graphs2
def prepare_data(graphs, train_ratio):
random.shuffle(graphs)

train_idx = int(len(graphs) * 0.8)
train_idx = int(len(graphs) * train_ratio)
train_graphs = graphs[:train_idx]
test_graphs = graphs[train_idx:]
print('Num training graphs: ', len(train_graphs),
'; Num testing graphs: ', len(test_graphs))

# minibatch
dataset_sampler = GraphSampler(train_graphs)
dataset_loader = torch.utils.data.DataLoader(
train_dataset_loader = torch.utils.data.DataLoader(
dataset_sampler,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers)

dataset_sampler = GraphSampler(test_graphs)
dataset_loader = torch.utils.data.DataLoader(
test_dataset_loader = torch.utils.data.DataLoader(
dataset_sampler,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers)

return train_dataset_loader, test_dataset_loader

def synthetic_task1(args, export_graphs=False):

# data
graphs1 = datagen.gen_ba(range(40, 60), range(4, 5), 500,
featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float) *0.5))
for G in graphs1:
G.graph['label'] = 0
if export_graphs:
util.draw_graph_list(graphs1[:16], 4, 4, 'figs/ba')

graphs2 = datagen.gen_2community_ba(range(20, 30), range(4, 5), 500, 0.3,
[featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)*0.5)])
for G in graphs2:
G.graph['label'] = 1
if export_graphs:
util.draw_graph_list(graphs2[:16], 4, 4, 'figs/ba2')

graphs = graphs1 + graphs2

train_dataset, test_dataset = prepare_data(graphs, 0.8)
model = encoders.GcnEncoderGraph(args.input_dim, args.hidden_dim, args.output_dim, 2, 2).cuda()
synthetic_task_train(dataset_loader, model, args)
synthetic_task_test(dataset_loader, model, args)
synthetic_task_train(train_dataset, model, args)
synthetic_task_test(train_dataset, model, args)

def benchmark_task(args):
graphs, _ = load_data.read_graphfile(args.datadir, args.bmname)
print('len', len(graphs))

def arg_parse():
parser = argparse.ArgumentParser(description='GraphPool arguments.')
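
For reference, a hedged sketch of how the new benchmark entry point might be driven; the --datadir and --bmname flags and the dispatch logic are assumptions inferred from args.datadir and args.bmname above, not shown in this diff.

# Hypothetical driver sketch (not part of this commit). Assumes arg_parse
# wires up --datadir and --bmname, which the diff only implies via
# args.datadir and args.bmname:
#
#   python train.py --datadir data/ --bmname ENZYMES
#
# A top-level main might then dispatch roughly as:
args = arg_parse()
if args.bmname is not None:
    benchmark_task(args)     # load a TU benchmark via load_data.read_graphfile
else:
    synthetic_task1(args)    # fall back to the synthetic BA-graph tasks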
