Skip to content

Commit

Permalink
some comments (dmlc#1132)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Jan 6, 2020
1 parent 6d0a9c4 commit e705118
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 63 deletions.
4 changes: 2 additions & 2 deletions examples/pytorch/recommendation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ First, download and extract from https://dgl.ai.s3.us-east-2.amazonaws.com/datas
One can then run the following to train PinSage on MovieLens-1M:

```bash
python3 main.py --opt Adam --lr 1e-3 --sched none --sgd-switch 25
python3 main.py --opt Adam --lr 1e-3
```

One can also incorporate user and movie features into training:

```bash
python3 main.py --opt Adam --lr 1e-3 --sched none --sgd-switch 25 --use-feature
python3 main.py --opt Adam --lr 1e-3 --use-feature
```

Currently, performance of PinSage on MovieLens-1M has the best mean reciprocal rank of
Expand Down
78 changes: 33 additions & 45 deletions examples/pytorch/recommendation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
pickle.dump(ml, f)

g = ml.g
neighbors = ml.user_neighbors + ml.movie_neighbors

n_hidden = 100
n_layers = args.layers
Expand All @@ -48,10 +47,6 @@
n_negs = args.n_negs
hard_neg_prob = args.hard_neg_prob

sched_lambda = {
'none': lambda epoch: 1,
'decay': lambda epoch: max(0.98 ** epoch, 1e-4),
}
loss_func = {
'hinge': lambda diff: (diff + margin).clamp(min=0).mean(),
'bpr': lambda diff: (1 - torch.sigmoid(-diff)).mean(),
Expand All @@ -67,7 +62,6 @@
G=g,
))
opt = getattr(torch.optim, args.opt)(model.parameters(), lr=args.lr)
sched = torch.optim.lr_scheduler.LambdaLR(opt, sched_lambda[args.sched])


def forward(model, g_prior, nodeset, train=True):
Expand All @@ -78,48 +72,37 @@ def forward(model, g_prior, nodeset, train=True):
return model(g_prior, nodeset)


def filter_nid(nids, nid_from):
    """Keep only the positions whose IDs are all contained in ``nid_from``.

    Parameters
    ----------
    nids : sequence of torch.Tensor
        Equal-length tensors of node IDs that are aligned position by
        position (for example source, destination and negative nodes).
    nid_from : torch.Tensor
        The node IDs that are allowed to remain.

    Returns
    -------
    list of torch.Tensor
        One tensor per input tensor, restricted to the positions at which
        *every* input tensor holds an ID found in ``nid_from``.
    """
    nids = [nid.numpy() for nid in nids]
    nid_from = nid_from.numpy()
    # ``np.logical_and(a, b, c)`` would interpret ``c`` as the ``out`` buffer
    # rather than as a third operand, so fold the masks with ``reduce`` to
    # handle any number of ID tensors correctly.
    np_mask = np.logical_and.reduce([np.isin(nid, nid_from) for nid in nids])
    return [torch.from_numpy(nid[np_mask]) for nid in nids]


def runtrain(g_prior_edges, g_train_edges, train):
def runtrain(g_train_bases, g_train_pairs, train):
global opt
if train:
model.train()
else:
model.eval()

g_prior_src, g_prior_dst = g.find_edges(g_prior_edges)
g_prior = DGLGraph()
g_prior.add_nodes(g.number_of_nodes())
g_prior.add_edges(g_prior_src, g_prior_dst)
g_prior.ndata.update({k: cuda(v) for k, v in g.ndata.items()})
edge_batches = g_train_edges[torch.randperm(g_train_edges.shape[0])].split(batch_size)
g_prior = g.edge_subgraph(g_train_bases, preserve_nodes=True)
g_prior.copy_from_parent()

# generate batches of training pairs
edge_batches = g_train_pairs[torch.randperm(g_train_pairs.shape[0])].split(batch_size)

with tqdm.tqdm(edge_batches) as tq:
sum_loss = 0
sum_acc = 0
count = 0
for batch_id, batch in enumerate(tq):
count += batch.shape[0]
# Get source (user) and destination (item) nodes, as well as negative items
src, dst = g.find_edges(batch)
dst_neg = []
for i in range(len(dst)):
if np.random.rand() < hard_neg_prob:
nb = torch.LongTensor(neighbors[dst[i].item()])
mask = ~(g.has_edges_between(nb, src[i].item()).byte())
dst_neg.append(np.random.choice(nb[mask].numpy(), n_negs))
else:
dst_neg.append(np.random.randint(
len(ml.user_ids), len(ml.user_ids) + len(ml.movie_ids), n_negs))
dst_neg.append(np.random.randint(
len(ml.user_ids), len(ml.user_ids) + len(ml.movie_ids), n_negs))
dst_neg = torch.LongTensor(dst_neg)
dst = dst.view(-1, 1).expand_as(dst_neg).flatten()
src = src.view(-1, 1).expand_as(dst_neg).flatten()
dst_neg = dst_neg.flatten()

# make sure that the source/destination/negative nodes each have at least one
# incoming edge in g_prior (in-degree > 0), dropping the samples that do not
mask = (g_prior.in_degrees(dst_neg) > 0) & \
(g_prior.in_degrees(dst) > 0) & \
(g_prior.in_degrees(src) > 0)
Expand All @@ -133,6 +116,7 @@ def runtrain(g_prior_edges, g_train_edges, train):
src_size, dst_size, dst_neg_size = \
src.shape[0], dst.shape[0], dst_neg.shape[0]

# get representations and compute losses
h_src, h_dst, h_dst_neg = (
forward(model, g_prior, nodeset, train)
.split([src_size, dst_size, dst_neg_size]))
Expand Down Expand Up @@ -163,18 +147,16 @@ def runtrain(g_prior_edges, g_train_edges, train):
return avg_loss, avg_acc


def runtest(g_prior_edges, validation=True):
def runtest(g_train_bases, ml, validation=True):
model.eval()

n_users = len(ml.users.index)
n_items = len(ml.movies.index)

g_prior_src, g_prior_dst = g.find_edges(g_prior_edges)
g_prior = DGLGraph()
g_prior.add_nodes(g.number_of_nodes())
g_prior.add_edges(g_prior_src, g_prior_dst)
g_prior.ndata.update({k: cuda(v) for k, v in g.ndata.items()})
g_prior = g.edge_subgraph(g_train_bases, preserve_nodes=True)
g_prior.copy_from_parent()

# Pre-compute the representations of users and items
hs = []
with torch.no_grad():
with tqdm.trange(n_users + n_items) as tq:
Expand All @@ -189,6 +171,10 @@ def runtest(g_prior_edges, validation=True):
with torch.no_grad():
with tqdm.trange(n_users) as tq:
for u_nid in tq:
# For each user, exclude the items appearing in
# (1) the training set, and
# (2) either the validation set when testing, or the test set when
# validating.
uid = ml.user_ids[u_nid]
pids_exclude = ml.ratings[
(ml.ratings['user_id'] == uid) &
Expand All @@ -202,6 +188,7 @@ def runtest(g_prior_edges, validation=True):
p_nids = np.array([ml.movie_ids_invmap[pid] for pid in pids])
p_nids_candidate = np.array([ml.movie_ids_invmap[pid] for pid in pids_candidate])

# compute scores of items and rank them, then compute the MRR.
dst = torch.from_numpy(p_nids) + n_users
src = torch.zeros_like(dst).fill_(u_nid)
h_dst = h[dst]
Expand All @@ -224,31 +211,32 @@ def train():
best_mrr = 0
for epoch in range(500):
ml.refresh_mask()
g_prior_edges = g.filter_edges(lambda edges: edges.data['prior'])
g_train_edges = g.filter_edges(lambda edges: edges.data['train'] & ~edges.data['inv'])
g_prior_train_edges = g.filter_edges(

# In training, we perform message passing on edges marked with 'prior', and
# do link prediction on edges marked with 'train'.
# 'prior' and 'train' are disjoint so that the training pairs cannot pass
# messages between each other.
# 'prior' and 'train' are re-generated every time with ml.refresh_mask() above.
g_train_bases = g.filter_edges(lambda edges: edges.data['prior'])
g_train_pairs = g.filter_edges(lambda edges: edges.data['train'] & ~edges.data['inv'])
# In testing we perform message passing on both 'prior' and 'train' edges.
g_test_bases = g.filter_edges(
lambda edges: edges.data['prior'] | edges.data['train'])

print('Epoch %d validation' % epoch)
with torch.no_grad():
valid_mrr = runtest(g_prior_train_edges, True)
valid_mrr = runtest(g_test_bases, ml, True)
if best_mrr < valid_mrr.mean():
best_mrr = valid_mrr.mean()
torch.save(model.state_dict(), 'model.pt')
print(pd.Series(valid_mrr).describe())
print('Epoch %d test' % epoch)
with torch.no_grad():
test_mrr = runtest(g_prior_train_edges, False)
test_mrr = runtest(g_test_bases, ml, False)
print(pd.Series(test_mrr).describe())

print('Epoch %d train' % epoch)
runtrain(g_prior_edges, g_train_edges, True)

if epoch == args.sgd_switch:
opt = torch.optim.SGD(model.parameters(), lr=0.6)
sched = torch.optim.lr_scheduler.LambdaLR(opt, sched_lambda['decay'])
elif epoch < args.sgd_switch:
sched.step()
runtrain(g_train_bases, g_train_pairs, True)


if __name__ == '__main__':
Expand Down
16 changes: 0 additions & 16 deletions examples/pytorch/recommendation/rec/datasets/movielens.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(self, directory):

self.data_split()
self.build_graph()
self.find_neighbors(0.2, 2000, 1000)

def split_user(self, df, filter_counts=False):
df_new = df.copy()
Expand Down Expand Up @@ -176,21 +175,6 @@ def build_graph(self):
data={'inv': torch.ones(self.ratings.shape[0], dtype=torch.uint8)})
self.g = g

def find_neighbors(self, restart_prob, max_nodes, top_T):
    """Collect per-node neighbor lists from a random-walk neighbor search.

    Calls ``randomwalk.random_walk_distribution_topt`` on every node of
    ``self.g`` and stores the returned neighbor entries as plain Python
    lists: the first ``len(self.user_ids)`` entries go to
    ``self.user_neighbors`` and the following ``len(self.movie_ids)``
    entries go to ``self.movie_neighbors``.

    ``restart_prob``, ``max_nodes`` and ``top_T`` are forwarded unchanged
    to the random-walk routine.
    """
    # TODO: replace with more efficient PPR estimation
    _, neighbors = randomwalk.random_walk_distribution_topt(
        self.g, self.g.nodes(), restart_prob, max_nodes, top_T)

    # presumably node IDs are laid out users first, then movies -- the
    # index ranges below rely on that ordering; verify against build_graph
    n_users = len(self.user_ids)
    self.user_neighbors = [
        neighbors[idx].tolist() for idx in range(n_users)
    ]

    n_total = len(self.user_ids) + len(self.movie_ids)
    self.movie_neighbors = [
        neighbors[idx].tolist() for idx in range(len(self.user_ids), n_total)
    ]

def generate_mask(self):
while True:
ratings = self.ratings.groupby('user_id', group_keys=False).apply(self.split_user)
Expand Down

0 comments on commit e705118

Please sign in to comment.