[Bugfix] Fix PinSAGE Benchmark (dmlc#4058)
* Update

* Update

* Update dgl.data.rst

* CI
mufeili authored May 27, 2022
1 parent 7a065a9 commit bef9930
Showing 7 changed files with 40 additions and 39 deletions.
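In short, the PinSAGE processing scripts no longer pickle the entire dataset, graph included, into a single `data.pkl`. They now write an output directory containing `train_g.bin`, saved with `dgl.save_graphs`, plus a `data.pkl` holding the remaining metadata, and the consumers (the benchmark and both example models) load the two pieces separately. A minimal sketch of the resulting load pattern, mirroring the diffs below; the directory path is whatever you passed to the processing script:

```python
import os
import pickle

import dgl

dataset_path = './data_processed'  # output directory of a processing script

# The training graph now lives in DGL's native binary format ...
g_list, _ = dgl.load_graphs(os.path.join(dataset_path, 'train_g.bin'))

# ... while the rest of the dataset (matrices, type names, etc.) stays pickled.
with open(os.path.join(dataset_path, 'data.pkl'), 'rb') as f:
    dataset = pickle.load(f)
dataset['train-graph'] = g_list[0]
```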
30 changes: 8 additions & 22 deletions benchmarks/benchmarks/utils.py
@@ -91,7 +91,7 @@ def get_graph(name, format = None):
g_list, _ = dgl.load_graphs(bin_path)
g = g_list[0]
else:
-    # the original node IDs of friendster are not consecutive, so we compact it
+        # the original node IDs of friendster are not consecutive, so we compact it
g = dgl.compact_graphs(get_friendster()).formats(format)
dgl.save_graphs(bin_path, [g])
elif name == "reddit":
@@ -253,24 +253,16 @@ def __getitem__(self, idx):

def load_nowplaying_rs():
import torchtext.legacy as torchtext
-    # follow examples/pytorch/pinsage/README to create nowplaying_rs.pkl
-    name = 'nowplaying_rs.pkl'
+    # follow examples/pytorch/pinsage/README to create train_g.bin
+    name = 'train_g.bin'
dataset_dir = os.path.join(os.getcwd(), 'dataset')
os.symlink('/tmp/dataset/', dataset_dir)

dataset_path = os.path.join(dataset_dir, "nowplaying_rs", name)
-    # Load dataset
-    with open(dataset_path, 'rb') as f:
-        dataset = pickle.load(f)
-
-    g = dataset['train-graph']
-    val_matrix = dataset['val-matrix'].tocsr()
-    test_matrix = dataset['test-matrix'].tocsr()
-    item_texts = dataset['item-texts']
-    user_ntype = dataset['user-type']
-    item_ntype = dataset['item-type']
-    user_to_item_etype = dataset['user-to-item-type']
-    timestamp = dataset['timestamp-edge-column']
+    g_list, _ = dgl.load_graphs(dataset_path)
+    g = g_list[0]
+    user_ntype = 'user'
+    item_ntype = 'track'

# Assign user and movie IDs and use them as features (to learn an individual trainable
# embedding for each entity)
@@ -282,17 +274,11 @@ def load_nowplaying_rs():
# Prepare torchtext dataset and vocabulary
fields = {}
examples = []
-    for key, texts in item_texts.items():
-        fields[key] = torchtext.data.Field(
-            include_lengths=True, lower=True, batch_first=True)
for i in range(g.number_of_nodes(item_ntype)):
example = torchtext.data.Example.fromlist(
-            [item_texts[key][i] for key in item_texts.keys()],
-            [(key, fields[key]) for key in item_texts.keys()])
+            [], [])
examples.append(example)
textset = torchtext.data.Dataset(examples, fields)
-    for key, field in fields.items():
-        field.build_vocab(getattr(textset, key))

return PinsageDataset(g, user_ntype, item_ntype, textset)

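Note that `load_nowplaying_rs` still returns a textset even though nowplaying-rs ships no item texts (the processing script below stores `'item-texts': {}`): it builds an empty `torchtext.legacy` dataset purely to keep the `PinsageDataset` interface unchanged. Reduced to a self-contained sketch, with the item count a stand-in:

```python
import torchtext.legacy as torchtext

num_items = 5  # stand-in for g.number_of_nodes(item_ntype)
fields = {}    # no text fields, hence no vocabularies to build
examples = [torchtext.data.Example.fromlist([], []) for _ in range(num_items)]
textset = torchtext.data.Dataset(examples, fields)
```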
2 changes: 1 addition & 1 deletion docs/source/api/python/dgl.data.rst
@@ -51,7 +51,6 @@ Datasets for node classification/regression tasks
BACommunityDataset
TreeCycleDataset
TreeGridDataset
-   BA2MotifDataset

Edge Prediction Datasets
---------------------------------------
@@ -88,6 +87,7 @@ Datasets for graph classification/regression tasks
LegacyTUDataset
GINDataset
FakeNewsDataset
+   BA2MotifDataset

Dataset adapters
-------------------
12 changes: 6 additions & 6 deletions examples/pytorch/pinsage/README.md
@@ -12,15 +12,15 @@

1. Download and extract the MovieLens-1M dataset from http://files.grouplens.org/datasets/movielens/ml-1m.zip
into the current directory.
-2. Run `python process_movielens1m.py ./ml-1m ./data.pkl`.
-   Replace `ml-1m` with the directory you put the `.dat` files, and replace `data.pkl` to
-   any path you wish to put the output pickle file.
+2. Run `python process_movielens1m.py ./ml-1m ./data_processed`.
+   Replace `ml-1m` with the directory where you put the `.dat` files, and replace `data_processed` with
+   any path where you wish to put the output files.

### Nowplaying-rs

1. Download and extract the Nowplaying-rs dataset from https://zenodo.org/record/3248543/files/nowplayingrs.zip?download=1
into the current directory.
-2. Run `python process_nowplaying_rs.py ./nowplaying_rs_dataset ./data.pkl`
+2. Run `python process_nowplaying_rs.py ./nowplaying_rs_dataset ./data_processed`

## Run model

@@ -31,7 +31,7 @@ interacted. The distance between two items are measured by Euclidean distance of
item embeddings, which are learned as outputs of PinSAGE.

```
-python model.py data.pkl --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 64
+python model.py data_processed --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 64
```

The implementation here also assigns a learnable vector to each item. If your hidden
@@ -40,7 +40,7 @@ for sparse embedding update (written with `torch.optim.SparseAdam`) instead:


```
-python model_sparse.py data.pkl --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 1024
+python model_sparse.py data_processed --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 1024
```

Note that since the embedding update is done on CPU, it will be significantly slower than doing
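The recommendation step described above, k nearest neighbors under Euclidean distance between learned item embeddings, comes down to a pairwise-distance query. A minimal sketch; the function name and tensors are illustrative, not part of the example code:

```python
import torch

def nearest_items(item_emb: torch.Tensor, query_ids: torch.Tensor, k: int = 10):
    """For each query item, return the k closest items by Euclidean distance."""
    dists = torch.cdist(item_emb[query_ids], item_emb)  # (num_queries, num_items)
    # Take k + 1 and drop the first column: the nearest "neighbor" is the
    # query item itself (distance 0, barring exact duplicates).
    return dists.topk(k + 1, largest=False).indices[:, 1:]
```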
7 changes: 6 additions & 1 deletion examples/pytorch/pinsage/model.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
import torchtext
import dgl
+import os
import tqdm

import layers
@@ -137,6 +138,10 @@ def train(dataset, args):
args = parser.parse_args()

# Load dataset
-    with open(args.dataset_path, 'rb') as f:
+    data_info_path = os.path.join(args.dataset_path, 'data.pkl')
+    with open(data_info_path, 'rb') as f:
dataset = pickle.load(f)
+    train_g_path = os.path.join(args.dataset_path, 'train_g.bin')
+    g_list, _ = dgl.load_graphs(train_g_path)
+    dataset['train-graph'] = g_list[0]
train(dataset, args)
7 changes: 6 additions & 1 deletion examples/pytorch/pinsage/model_sparse.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
import torchtext
import dgl
+import os
import tqdm

import layers
@@ -142,6 +143,10 @@ def train(dataset, args):
args = parser.parse_args()

# Load dataset
-    with open(args.dataset_path, 'rb') as f:
+    data_info_path = os.path.join(args.dataset_path, 'data.pkl')
+    with open(data_info_path, 'rb') as f:
dataset = pickle.load(f)
+    train_g_path = os.path.join(args.dataset_path, 'train_g.bin')
+    g_list, _ = dgl.load_graphs(train_g_path)
+    dataset['train-graph'] = g_list[0]
train(dataset, args)
10 changes: 6 additions & 4 deletions examples/pytorch/pinsage/process_movielens1m.py
@@ -28,10 +28,11 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('directory', type=str)
-    parser.add_argument('output_path', type=str)
+    parser.add_argument('out_directory', type=str)
args = parser.parse_args()
directory = args.directory
-    output_path = args.output_path
+    out_directory = args.out_directory
+    os.makedirs(out_directory, exist_ok=True)

## Build heterogeneous graph

@@ -139,8 +140,9 @@

## Dump the graph and the datasets

+    dgl.save_graphs(os.path.join(out_directory, 'train_g.bin'), train_g)
+
dataset = {
-        'train-graph': train_g,
'val-matrix': val_matrix,
'test-matrix': test_matrix,
'item-texts': movie_textual_dataset,
@@ -151,5 +153,5 @@
'item-to-user-type': 'watched-by',
'timestamp-edge-column': 'timestamp'}

-    with open(output_path, 'wb') as f:
+    with open(os.path.join(out_directory, 'data.pkl'), 'wb') as f:
pickle.dump(dataset, f)
11 changes: 7 additions & 4 deletions examples/pytorch/pinsage/process_nowplaying_rs.py
@@ -5,6 +5,7 @@

import os
import argparse
+import dgl
import pandas as pd
import scipy.sparse as ssp
import pickle
@@ -14,10 +14,11 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('directory', type=str)
-    parser.add_argument('output_path', type=str)
+    parser.add_argument('out_directory', type=str)
args = parser.parse_args()
directory = args.directory
-    output_path = args.output_path
+    out_directory = args.out_directory
+    os.makedirs(out_directory, exist_ok=True)

data = pd.read_csv(os.path.join(directory, 'context_content_features.csv'))
track_feature_cols = list(data.columns[1:13])
@@ -59,8 +61,9 @@
val_matrix, test_matrix = build_val_test_matrix(
g, val_indices, test_indices, 'user', 'track', 'listened')

+    dgl.save_graphs(os.path.join(out_directory, 'train_g.bin'), train_g)
+
dataset = {
-        'train-graph': train_g,
'val-matrix': val_matrix,
'test-matrix': test_matrix,
'item-texts': {},
@@ -71,5 +74,5 @@
'item-to-user-type': 'listened-by',
'timestamp-edge-column': 'created_at'}

-    with open(output_path, 'wb') as f:
+    with open(os.path.join(out_directory, 'data.pkl'), 'wb') as f:
pickle.dump(dataset, f)
