Skip to content

Commit

Permalink
load data
Browse files Browse the repository at this point in the history
  • Loading branch information
RexYing committed Apr 18, 2018
1 parent 5de68e7 commit 1523141
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 3 deletions.
106 changes: 106 additions & 0 deletions load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import numpy as np
import scipy as sc
import networkx as nx
from graph_functions import *


def read_graphfile(datadir, dataname):
    """Read a benchmark dataset in the standard graph-kernel text format.

    Expects the files
        <datadir><dataname>/<dataname>_graph_indicator.txt
        <datadir><dataname>/<dataname>_node_labels.txt
        <datadir><dataname>/<dataname>_graph_labels.txt
        <datadir><dataname>/<dataname>_A.txt
    (note: datadir is concatenated directly, so it should end with "/").

    Returns:
        graphs: dict graph_id (1-based) -> networkx graph with nodes relabeled
            to consecutive 0-based integers and graph['label'] set.
        adj_list: dict graph_id -> list of (u, v) edges with original 1-based ids.
        node_labels: np.array of per-node integer labels (0-indexed by node).
        graph_labels: np.array of per-graph integer labels (0-indexed by graph).
        graph_indic: dict node_id (1-based) -> graph_id (1-based).
        index_graph: dict graph_id -> sorted-unique list of 0-based node indices
            that appear in that graph's edge list.
    """
    prefix = datadir + dataname + "/" + dataname

    # node id (1-based) -> graph id (1-based)
    graph_indic = {}
    with open(prefix + "_graph_indicator.txt") as f:
        for i, line in enumerate(f, 1):
            graph_indic[i] = int(line.strip("\n"))

    node_labels = []
    with open(prefix + "_node_labels.txt") as f:
        for line in f:
            node_labels.append(int(line.strip("\n")))

    graph_labels = []
    with open(prefix + "_graph_labels.txt") as f:
        for line in f:
            graph_labels.append(int(line.strip("\n")))

    adj_list = {i: [] for i in range(1, len(graph_labels) + 1)}
    index_graph = {i: [] for i in range(1, len(graph_labels) + 1)}
    num_edges = 0
    # BUG FIX: the original opened "../data/" + ... here, silently ignoring
    # datadir even though every other file above is read from datadir.
    with open(prefix + "_A.txt") as f:
        for line in f:
            fields = line.strip("\n").split(",")
            e0, e1 = int(fields[0].strip(" ")), int(fields[1].strip(" "))
            # An edge belongs to the graph of its first endpoint
            # (both endpoints are in the same graph in this format).
            adj_list[graph_indic[e0]].append((e0, e1))
            index_graph[graph_indic[e0]] += [e0, e1]
            num_edges += 1
    # Convert to unique 0-based node indices per graph.
    for k in index_graph.keys():
        index_graph[k] = [u - 1 for u in set(index_graph[k])]

    # Sanity check: every edge read was assigned to exactly one graph.
    print("check", np.sum([len(adj_list[i]) for i in adj_list.keys()]), num_edges)

    graphs = {}
    for i in range(1, 1 + len(adj_list)):
        G = nx.from_edgelist(adj_list[i])
        # BUG FIX: graph ids i are 1-based but graph_labels is a 0-based list;
        # graph_labels[i] was off by one and raised IndexError for the last graph.
        # (load_data below already indexes targets with graph_labels[i-1].)
        G.graph['label'] = graph_labels[i - 1]
        # Relabel nodes to consecutive 0-based ids. G.nodes() works in both
        # networkx 1.x and 2.x, so no version branch is needed; the original
        # float(nx.__version__) check also crashed on versions like "2.6.3".
        mapping = {n: it for it, n in enumerate(G.nodes())}
        graphs[i] = nx.relabel_nodes(G, mapping)

    return graphs, adj_list, np.array(node_labels), np.array(graph_labels), graph_indic, index_graph

def load_data(datadir, dataname, train_size=1000):
    """Load a benchmark dataset and randomly split it into train/test sets.

    Args:
        datadir: base data directory (concatenated directly; should end in "/").
        dataname: benchmark dataset name.
        train_size: number of graphs assigned to the training split
            (default 1000, matching the original hard-coded split).

    Returns:
        (train_graphs, train_targets, train_features,
         test_graphs, test_targets, test_features)
        where features are per-graph (num_nodes, n_classes + 1) one-hot
        matrices of the node labels, and targets are np.arrays of graph labels.
    """
    graphs, adj_list, node_labels, graph_labels, graph_indic, index_graph = \
        read_graphfile(datadir, dataname)

    # list() so the in-place shuffle works on Python 3 (range is immutable);
    # graph ids are 1-based.
    ind = list(range(1, 1 + len(graphs)))
    np.random.shuffle(ind)
    training_set = ind[:train_size]
    test_set = ind[train_size:]

    n_classes = np.max(node_labels)

    def _one_hot_features(graph_ids):
        # One-hot encode the node labels of each graph:
        # row j of graph i's matrix has a 1 at column node_labels[node_j].
        feats = []
        for i in graph_ids:
            cand_nodes = index_graph[i]
            x = np.zeros((len(cand_nodes), n_classes + 1))
            x[np.arange(len(cand_nodes)), node_labels[cand_nodes]] = 1
            feats.append(x)
        return feats

    train_features = _one_hot_features(training_set)
    # CONSISTENCY FIX: the original built test features from
    # np.where(graph_indic_array == i) while train features used
    # index_graph[i]; both name the same graph's nodes, so use one helper.
    test_features = _one_hot_features(test_set)

    # graph_labels is 0-based while graph ids are 1-based, hence i - 1.
    train_graphs = [graphs[i] for i in training_set]
    train_targets = [graph_labels[i - 1] for i in training_set]
    test_graphs = [graphs[i] for i in test_set]
    test_targets = [graph_labels[i - 1] for i in test_set]
    return train_graphs, np.array(train_targets), train_features, \
        test_graphs, np.array(test_targets), test_features

12 changes: 9 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def synthetic_task_test(dataset, model, args):
preds = np.hstack(preds)

print("Validation F1:", metrics.f1_score(labels, preds, average="micro"))
print(labels)
print(preds)
print("Validation prec:", metrics.precision_score(labels, preds))
print("Validation recall:", metrics.recall_score(labels, preds))

Expand Down Expand Up @@ -110,12 +108,19 @@ def synthetic_task1(args, export_graphs=False):
model = encoders.GcnEncoderGraph(args.input_dim, args.hidden_dim, args.output_dim, 2, 2).cuda()
synthetic_task_train(dataset_loader, model, args)
synthetic_task_test(dataset_loader, model, args)

def benchmark_task(args):

def arg_parse():
parser = argparse.ArgumentParser(description='GraphPool arguments.')
io_parser = parser.add_mutually_exclusive_group(required=False)
io_parser.add_argument('--dataset', dest='dataset',
help='Input dataset.')
benchmark_parser = io_parser.add_mutually_exclusive_group(required=False)
benchmark_parser.add_argument('--datadir', dest='datadir',
help='Directory where benchmark is located')
benchmark_parser.add_argument('--bmname', dest='bmname',
help='Name of the benchmark dataset')

parser.add_argument('--cuda', dest='cuda',
help='CUDA.')
Expand All @@ -136,7 +141,8 @@ def arg_parse():
parser.add_argument('--output_dim', dest='output_dim', type=int,
help='Output dimension')

parser.set_defaults(cuda='1',
parser.set_defaults(dataset='synthetic1',
cuda='1',
feature_type='default',
lr=0.001,
batch_size=10,
Expand Down

0 comments on commit 1523141

Please sign in to comment.