[Model] Pinsage example that uses sparse embedding update (dmlc#1676)
* pinsage example that uses sparse embedding update

* add difference from paper
BarclayII authored Jul 6, 2020
1 parent 9cd0d3f commit 2f6ab43
Showing 2 changed files with 179 additions and 1 deletion.
33 changes: 32 additions & 1 deletion examples/pytorch/pinsage/README.md
@@ -28,4 +28,35 @@ item embeddings, which are learned as outputs of PinSAGE.
python model.py data.pkl --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 64
```

The implementation here also assigns a learnable vector to each item. If your hidden
state size is so large that the learnable vectors cannot fit into GPU memory, use this
script instead, which performs sparse embedding updates with `torch.optim.SparseAdam`:


```
python model_sparse.py data.pkl --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 1024
```

Note that since the embedding update is done on CPU, training will be significantly slower
than doing everything on GPU.

The HITS@10 is 0.01241, compared to 0.01220 with SLIM with the same dimensionality.
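
Under the hood, `model_sparse.py` keeps the item embedding table on CPU with `sparse=True`
and updates it with `torch.optim.SparseAdam`, while the GNN itself trains on GPU with a
regular dense `Adam`. Below is a minimal sketch of that pattern, assuming a CUDA device is
available; the `nn.Linear` stand-in for the GNN, the table size, and the learning rate are
illustrative only, not part of this example:

```
import torch
import torch.nn as nn

device = torch.device('cuda:0')

# Dense model parameters live on GPU; the (potentially huge) embedding
# table stays on CPU and only receives sparse row-wise gradient updates.
model = nn.Linear(1024, 1024).to(device)               # stand-in for the GNN
item_emb = nn.Embedding(1_000_000, 1024, sparse=True)  # CPU-resident table

opt = torch.optim.Adam(model.parameters(), lr=3e-5)
opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=3e-5)

item_ids = torch.randint(0, 1_000_000, (32,))  # a batch of item IDs
# Look up on CPU, then move only this batch's rows to the GPU.
h = item_emb(item_ids).to(device)
loss = model(h).pow(2).mean()  # dummy loss for illustration

opt.zero_grad()
opt_emb.zero_grad()
loss.backward()  # yields a sparse gradient for item_emb.weight
opt.step()
opt_emb.step()   # updates only the embedding rows touched in this batch
```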

## Difference from the paper

The implementation here differs from what is described in the paper:

1. The paper describes a supervised setting where the authors have a ground-truth set of which items are
relevant. However, in traditional recommender system datasets we don't have such labels; we only know
which items each user has interacted with (as well as the users' and items' own features). I therefore
adapted PinSAGE to an unsupervised setting where the model predicts whether two items are co-interacted
by the same user (see the loss sketch after this list).
2. The PinSAGE paper explicitly states that the items do not have learnable node embeddings; instead, the
embeddings are expressed directly as a function of node features. While this is reasonable for rich
datasets like Pinterest's, where images and text are enough to distinguish the items from each other, it
is unfortunately not the case for traditional recommender system datasets like MovieLens or Nowplaying-RS,
where we only have a handful of categorical or numeric variables. I found that adding a learnable
embedding for each item still helps on those datasets.
3. The PinSAGE paper passes the GNN output directly to an MLP and takes the result as the final item
representation. Here, I instead add the node's own learnable embedding to the GNN output to form the
final item representation.
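
To make the loss in point 1 concrete, here is a minimal sketch of the max-margin objective
used by this example; the tensor values below are made up for illustration, and the real
scores come from `ItemToItemScorer` in the model code:

```
import torch

def max_margin_loss(pos_score, neg_score, margin=1.0):
    # Items co-interacted by the same user (positive pairs) should score at
    # least `margin` higher than randomly sampled (negative) pairs.
    return (neg_score - pos_score + margin).clamp(min=0).mean()

# Illustrative scores for two positive and two negative item pairs.
pos = torch.tensor([0.9, 0.2])
neg = torch.tensor([0.1, 0.5])
print(max_margin_loss(pos, neg))  # tensor(0.7500)
```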
147 changes: 147 additions & 0 deletions examples/pytorch/pinsage/model_sparse.py
@@ -0,0 +1,147 @@
import pickle
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
import dgl
import tqdm

import layers
import sampler as sampler_module
import evaluation

class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

        self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks, item_emb):
        h_item = self.get_repr(blocks, item_emb)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
        # Max-margin (hinge) loss with margin 1: positive pairs should score
        # at least 1 higher than negative pairs.
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_repr(self, blocks, item_emb):
        # project features
        h_item = self.proj(blocks[0].srcdata)
        h_item_dst = self.proj(blocks[-1].dstdata)

        # Add each item's learnable embedding; the lookup runs on CPU (where
        # the sparse table lives) and the result moves to the feature device.
        h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item)
        h_item_dst = h_item_dst + item_emb(blocks[-1].dstdata[dgl.NID].cpu()).to(h_item_dst)

        return h_item_dst + self.sage(blocks, h_item)

def train(dataset, args):
    g = dataset['train-graph']
    val_matrix = dataset['val-matrix'].tocsr()
    test_matrix = dataset['test-matrix'].tocsr()
    item_texts = dataset['item-texts']
    user_ntype = dataset['user-type']
    item_ntype = dataset['item-type']
    user_to_item_etype = dataset['user-to-item-type']
    timestamp = dataset['timestamp-edge-column']

    device = torch.device(args.device)

    # Prepare torchtext dataset and vocabulary
    fields = {}
    examples = []
    for key, texts in item_texts.items():
        fields[key] = torchtext.data.Field(include_lengths=True, lower=True, batch_first=True)
    for i in range(g.number_of_nodes(item_ntype)):
        example = torchtext.data.Example.fromlist(
            [item_texts[key][i] for key in item_texts.keys()],
            [(key, fields[key]) for key in item_texts.keys()])
        examples.append(example)
    textset = torchtext.data.Dataset(examples, fields)
    for key, field in fields.items():
        field.build_vocab(getattr(textset, key))
        #field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')

    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size)
    neighbor_sampler = sampler_module.NeighborSampler(
        g, user_ntype, item_ntype, args.random_walk_length,
        args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
        args.num_layers)
    collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers)
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers)
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device)
    # Learnable per-item embeddings with sparse gradients; intentionally kept
    # on CPU so a large table does not consume GPU memory.
    item_emb = nn.Embedding(g.number_of_nodes(item_ntype), args.hidden_dims, sparse=True)
    # Optimizers: dense Adam for the model, SparseAdam for the embedding table.
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr)

    # For each batch of head-tail-negative triplets...
    for epoch_id in range(args.num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks, item_emb).mean()
            opt.zero_grad()
            opt_emb.zero_grad()
            loss.backward()
            opt.step()
            opt_emb.step()

        # Evaluate
        model.eval()
        with torch.no_grad():
            item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
            h_item_batches = []
            for blocks in tqdm.tqdm(dataloader_test):
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks, item_emb))
            h_item = torch.cat(h_item_batches, 0)

            print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size))

if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path', type=str)
    parser.add_argument('--random-walk-length', type=int, default=2)
    parser.add_argument('--random-walk-restart-prob', type=float, default=0.5)
    parser.add_argument('--num-random-walks', type=int, default=10)
    parser.add_argument('--num-neighbors', type=int, default=3)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--hidden-dims', type=int, default=16)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--device', type=str, default='cpu')  # can also be "cuda:0"
    parser.add_argument('--num-epochs', type=int, default=1)
    parser.add_argument('--batches-per-epoch', type=int, default=20000)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--lr', type=float, default=3e-5)
    parser.add_argument('-k', type=int, default=10)
    args = parser.parse_args()

    # Load dataset
    with open(args.dataset_path, 'rb') as f:
        dataset = pickle.load(f)
    train(dataset, args)
