Skip to content

Commit

Permalink
Modify negative sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Anonymized committed Mar 23, 2019
1 parent 2babf1c commit e9bf240
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 22 deletions.
Binary file modified src/__pycache__/dataCenter.cpython-35.pyc
Binary file not shown.
Binary file modified src/__pycache__/main.cpython-35.pyc
Binary file not shown.
Binary file modified src/__pycache__/models.cpython-35.pyc
Binary file not shown.
Binary file modified src/__pycache__/utils.cpython-35.pyc
Binary file not shown.
6 changes: 4 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
parser.add_argument('--dataSet', type=str, default='cora')
parser.add_argument('--agg_func', type=str, default='MEAN')
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--b_sz', type=int, default=50)
parser.add_argument('--b_sz', type=int, default=20)
parser.add_argument('--seed', type=int, default=824)
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
Expand Down Expand Up @@ -56,9 +56,11 @@
classification = Classification(config['setting.hidden_emb_size'], num_labels)
classification.to(device)

unsupervised_loss = UnsupervisedLoss(getattr(dataCenter, ds+'_adj_lists'), getattr(dataCenter, ds+'_train'))

for epoch in range(args.epochs):
print('----------------------EPOCH %d-----------------------' % epoch)
apply_model(dataCenter, ds, graphSage, classification, args.b_sz, device)
apply_model(dataCenter, ds, graphSage, classification, unsupervised_loss, args.b_sz, device)
evaluate(dataCenter, ds, graphSage, classification, args.b_sz, device)


60 changes: 60 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,66 @@ def forward(self, embeds):
logists = torch.log_softmax(torch.mm(embeds,self.weight), 1)
return logists

class UnsupervisedLoss(object):
    """Pair sampler for the unsupervised GraphSAGE loss.

    Positive pairs are co-occurrences on short random walks; negative
    pairs are training nodes more than ``WALK_LEN`` hops away from the
    anchor node.
    """
    def __init__(self, adj_lists, train_nodes):
        super(UnsupervisedLoss, self).__init__()
        self.N_WALKS = 2    # random walks started per anchor node
        self.WALK_LEN = 3   # steps per walk; also the "far" hop radius for negatives
        self.adj_lists = adj_lists      # node id -> set of neighbor ids
        self.train_nodes = train_nodes  # candidate pool for positives/negatives

        self.positive_pairs = []
        self.negtive_pairs = []         # (sic) spelling kept for external compatibility
        self.unique_nodes_batch = []

    def forward(self):
        # Placeholder: the actual loss computation happens elsewhere.
        pass

    def extend_nodes(self, nodes):
        """Sample positive and negative pairs for ``nodes`` and return the
        set of every node id appearing in either pair list."""
        self.get_positive_nodes(nodes)
        # print(self.positive_pairs)
        self.get_negtive_nodes(nodes)
        # print(self.negtive_pairs)
        self.unique_nodes_batch = (
            set([i for x in self.positive_pairs for i in x])
            | set([i for x in self.negtive_pairs for i in x]))
        return self.unique_nodes_batch

    def get_positive_nodes(self, nodes):
        """Positive pairs: random-walk co-occurrences (see _run_random_walks)."""
        return self._run_random_walks(nodes)

    def get_negtive_nodes(self, nodes):
        """Negative pairs: for each node, sample train nodes more than
        WALK_LEN hops away (found by a breadth-first frontier expansion)."""
        self.negtive_pairs = []
        for node in nodes:
            neighbors = set([node])
            frontier = set([node])
            for i in range(self.WALK_LEN):
                current = set()
                for outer in frontier:
                    current |= self.adj_lists[int(outer)]
                frontier = current - neighbors
                neighbors |= current
            far_nodes = set(self.train_nodes) - neighbors
            # random.sample requires a sequence: sampling directly from a
            # set is a TypeError on Python 3.11+, so convert first.
            if self.N_WALKS * self.WALK_LEN < len(far_nodes):
                neg_samples = random.sample(list(far_nodes), self.N_WALKS * self.WALK_LEN)
            else:
                neg_samples = far_nodes  # too few far nodes: take them all
            self.negtive_pairs.extend([(node, neg_node) for neg_node in neg_samples])
        return self.negtive_pairs

    def _run_random_walks(self, nodes):
        """Run N_WALKS random walks of WALK_LEN steps from each anchor and
        collect (anchor, visited) pairs for visited train nodes."""
        self.positive_pairs = []
        for node in nodes:
            if len(self.adj_lists[int(node)]) == 0:
                continue
            for i in range(self.N_WALKS):
                curr_node = node
                for j in range(self.WALK_LEN):
                    neighs = self.adj_lists[int(curr_node)]
                    if not neighs:
                        # Dead end: random.choice on an empty list would
                        # raise IndexError, so stop this walk early.
                        break
                    next_node = random.choice(list(neighs))
                    # self co-occurrences are useless
                    if curr_node != node and curr_node in self.train_nodes:
                        self.positive_pairs.append((node, curr_node))
                    curr_node = next_node
        return self.positive_pairs


class SageLayer(nn.Module):
"""
Encodes a node's using 'convolutional' GraphSage approach
Expand Down
30 changes: 10 additions & 20 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from sklearn.utils import shuffle
from sklearn.metrics import f1_score

import torch.nn as nn
import numpy as np

def evaluate(dataCenter, ds, graphSage, classification, b_sz, device):
test_nodes = getattr(dataCenter, ds+'_test')
Expand Down Expand Up @@ -43,7 +45,7 @@ def evaluate(dataCenter, ds, graphSage, classification, b_sz, device):
for param in params:
param.requires_grad = True

def apply_model(dataCenter, ds, graphSage, classification, b_sz, device):
def apply_model(dataCenter, ds, graphSage, classification, unsupervised_loss, b_sz, device):
test_nodes = getattr(dataCenter, ds+'_test')
val_nodes = getattr(dataCenter, ds+'_val')
train_nodes = getattr(dataCenter, ds+'_train')
Expand All @@ -67,14 +69,17 @@ def apply_model(dataCenter, ds, graphSage, classification, b_sz, device):

batches = math.ceil(len(train_nodes) / b_sz)

visited_nodes = set()
for index in range(batches):
nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]
nodes_batch = np.asarray(list(unsupervised_loss.extend_nodes(nodes_batch)))
visited_nodes |= set(nodes_batch)
labels_batch = labels[nodes_batch]
embs_batch = graphSage(nodes_batch)
logists = classification(embs_batch)
loss = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
loss /= len(nodes_batch)
print(loss.item())
print('Step {}, Loss: {:.4f}, Dealed Nodes [{}/{}] '.format(index, loss.item(), len(visited_nodes), len(train_nodes)))
loss.backward()
for model in models:
nn.utils.clip_grad_norm_(model.parameters(), 5)
Expand All @@ -84,24 +89,9 @@ def apply_model(dataCenter, ds, graphSage, classification, b_sz, device):
for model in models:
model.zero_grad()

N_WALKS = 10
WALK_LEN = 10
def run_random_walks(G, nodes, num_walks=N_WALKS):
pairs = []
for count, node in enumerate(nodes):
if G.degree(node) == 0:
continue
for i in range(num_walks):
curr_node = node
for j in range(WALK_LEN):
next_node = random.choice(G.neighbors(curr_node))
# self co-occurrences are useless
if curr_node != node:
pairs.append((node,curr_node))
curr_node = next_node
if count % 1000 == 0:
print("Done walks for", count, "nodes")
return pairs
if visited_nodes == set(train_nodes):
return


# def run_cora(device, dataCenter, data):
# feat_data, labels, adj_lists = data
Expand Down

0 comments on commit e9bf240

Please sign in to comment.