-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] Temporal Graph Neural Network (dmlc#2636)
* Add hgat example * Add experiment * Clean code * clear the code * Add index in README * Add index in README * Add index in README * Add index in README * Add index in README * Add index in README * Change the code title and folder name * Ready to merge * Prepare for rebase and change message passing function * use git ignore to handle empty file * change file permission to resolve empty file * Change permission * change file mode * Finish Coding * working code cpu * pyg compare * Accelerate with batching * FastMode Enabled * update readme * Update README.md * refractor code * Fix Bug * add a simple temporal sampling method * test results * add train speed result * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * fix bug * Fixed Grammar and Format Issue Co-authored-by: Chen <[email protected]> Co-authored-by: Tianjun Xiao <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: WangXuhongCN <[email protected]> Co-authored-by: WangXuhongCN <[email protected]>
- Loading branch information
1 parent
6999f88
commit d04d59e
Showing
7 changed files
with
1,788 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Temporal Graph Neural Network (TGN) | ||
|
||
## DGL Implementation of tgn paper. | ||
|
||
This DGL examples implements the GNN mode proposed in the paper [TemporalGraphNeuralNetwork](https://arxiv.org/abs/2006.10637.pdf) | ||
|
||
## TGN implementor | ||
|
||
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his SDE internship at the AWS Shanghai AI Lab. | ||
|
||
## Graph Dataset | ||
|
||
Jodie Wikipedia Temporal dataset. Dataset summary: | ||
|
||
- Num Nodes: 9227 | ||
- Num Edges: 157, 474 | ||
- Num Edge Features: 172 | ||
- Edge Feature type: LIWC | ||
- Time Span: 30 days | ||
- Chronological Split: Train: 70% Valid: 15% Test: 15% | ||
|
||
Jodie Reddit Temporal dataset. Dataset summary: | ||
|
||
- Num Nodes: 11,000 | ||
- Num Edges: 672, 447 | ||
- Num Edge Features: 172 | ||
- Edge Feature type: LIWC | ||
- Time Span: 30 days | ||
- Chronological Split: Train: 70% Valid: 15% Test: 15% | ||
|
||
## How to run example files | ||
|
||
In tgn folder, run | ||
|
||
**please use `train.py`** | ||
|
||
```python | ||
python train.py --dataset wikipedia | ||
``` | ||
|
||
If you want to run in fast mode: | ||
|
||
```python | ||
python train.py --dataset wikipedia --fast_mode | ||
``` | ||
|
||
If you want to run in simple mode: | ||
|
||
```python | ||
python train.py --dataset wikipedia --simple_mode | ||
``` | ||
|
||
If you want to change memory updating module: | ||
|
||
```python | ||
python train.py --dataset wikipedia --memory_updater [rnn/gru] | ||
``` | ||
|
||
## Performance | ||
|
||
#### Without New Node in test set | ||
|
||
| Models/Datasets | Wikipedia | Reddit | | ||
| --------------- | ------------------ | ---------------- | | ||
| TGN simple mode | AP: 98.5 AUC: 98.9 | AP: N/A AUC: N/A | | ||
| TGN fast mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A | | ||
| TGN | AP: 98.9 AUC: 98.5 | AP: N/A AUC: N/A | | ||
|
||
#### With New Node in test set | ||
|
||
| Models/Datasets | Wikipedia | Reddit | | ||
| --------------- | ------------------- | ---------------- | | ||
| TGN simple mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A | | ||
| TGN fast mode | AP: 98.0 AUC: 98.4 | AP: N/A AUC: N/A | | ||
| TGN | AP: 98.2 AUC: 98.1 | AP: N/A AUC: N/A | | ||
|
||
## Training Speed / Batch | ||
Intel E5 2cores, Tesla K80, Wikipedia Dataset | ||
|
||
| Models/Datasets | Wikipedia | Reddit | | ||
| --------------- | --------- | -------- | | ||
| TGN simple mode | 0.3s | N/A | | ||
| TGN fast mode | 0.28s | N/A | | ||
| TGN | 1.3s | N/A | | ||
|
||
### Details explained | ||
|
||
**What is Simple Mode** | ||
|
||
Simple Temporal Sampler just choose the edges that happen before the current timestamp and build the subgraph of the corresponding nodes. | ||
And then the simple sampler uses the static graph neighborhood sampling methods. | ||
|
||
**What is Fast Mode** | ||
|
||
Normally temporal encoding needs each node to use incoming time frame as current time which might lead to two nodes have multiple interactions within the same batch need to maintain multiple embedding features which slow down the batching process to avoid feature duplication, fast mode enables fast batching since it uses last memory update time in the last batch as temporal encoding benchmark for each node. Also within each batch, all interaction between two nodes are predicted using the same set of embedding feature | ||
|
||
**What is New Node test** | ||
|
||
To test the model has the ability to predict link between unseen nodes based on neighboring information of seen nodes. This model deliberately select 10 % of node in test graph and mask them out during the training | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os | ||
import ssl | ||
from six.moves import urllib | ||
|
||
import pandas as pd | ||
import numpy as np | ||
|
||
import torch | ||
import dgl | ||
|
||
# === Below data preprocessing code are based on | ||
# https://github.com/twitter-research/tgn | ||
|
||
# Preprocess the raw data split each features | ||
|
||
def preprocess(data_name): | ||
u_list, i_list, ts_list, label_list = [], [], [], [] | ||
feat_l = [] | ||
idx_list = [] | ||
|
||
with open(data_name) as f: | ||
s = next(f) | ||
for idx, line in enumerate(f): | ||
e = line.strip().split(',') | ||
u = int(e[0]) | ||
i = int(e[1]) | ||
|
||
ts = float(e[2]) | ||
label = float(e[3]) # int(e[3]) | ||
|
||
feat = np.array([float(x) for x in e[4:]]) | ||
|
||
u_list.append(u) | ||
i_list.append(i) | ||
ts_list.append(ts) | ||
label_list.append(label) | ||
idx_list.append(idx) | ||
|
||
feat_l.append(feat) | ||
return pd.DataFrame({'u': u_list, | ||
'i': i_list, | ||
'ts': ts_list, | ||
'label': label_list, | ||
'idx': idx_list}), np.array(feat_l) | ||
|
||
# Re index nodes for DGL convience | ||
def reindex(df, bipartite=True): | ||
new_df = df.copy() | ||
if bipartite: | ||
assert (df.u.max() - df.u.min() + 1 == len(df.u.unique())) | ||
assert (df.i.max() - df.i.min() + 1 == len(df.i.unique())) | ||
|
||
upper_u = df.u.max() + 1 | ||
new_i = df.i + upper_u | ||
|
||
new_df.i = new_i | ||
new_df.u += 1 | ||
new_df.i += 1 | ||
new_df.idx += 1 | ||
else: | ||
new_df.u += 1 | ||
new_df.i += 1 | ||
new_df.idx += 1 | ||
|
||
return new_df | ||
|
||
# Save edge list, features in different file for data easy process data | ||
def run(data_name, bipartite=True): | ||
PATH = './data/{}.csv'.format(data_name) | ||
OUT_DF = './data/ml_{}.csv'.format(data_name) | ||
OUT_FEAT = './data/ml_{}.npy'.format(data_name) | ||
OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name) | ||
|
||
df, feat = preprocess(PATH) | ||
new_df = reindex(df, bipartite) | ||
|
||
empty = np.zeros(feat.shape[1])[np.newaxis, :] | ||
feat = np.vstack([empty, feat]) | ||
|
||
max_idx = max(new_df.u.max(), new_df.i.max()) | ||
rand_feat = np.zeros((max_idx + 1, 172)) | ||
|
||
new_df.to_csv(OUT_DF) | ||
np.save(OUT_FEAT, feat) | ||
np.save(OUT_NODE_FEAT, rand_feat) | ||
|
||
# === code from twitter-research-tgn end === | ||
|
||
# If you have new dataset follow by same format in Jodie, | ||
# you can directly use name to retrieve dataset | ||
|
||
def TemporalDataset(dataset): | ||
if not os.path.exists('./data/{}.bin'.format(dataset)): | ||
if not os.path.exists('./data/{}.csv'.format(dataset)): | ||
if not os.path.exists('./data'): | ||
os.mkdir('./data') | ||
|
||
url = 'https://snap.stanford.edu/jodie/{}.csv'.format(dataset) | ||
print("Start Downloading File....") | ||
context = ssl._create_unverified_context() | ||
data = urllib.request.urlopen(url, context=context) | ||
with open("./data/{}.csv".format(dataset), "wb") as handle: | ||
handle.write(data.read()) | ||
|
||
print("Start Process Data ...") | ||
run(dataset) | ||
raw_connection = pd.read_csv('./data/ml_{}.csv'.format(dataset)) | ||
raw_feature = np.load('./data/ml_{}.npy'.format(dataset)) | ||
# -1 for re-index the node | ||
src = raw_connection['u'].to_numpy()-1 | ||
dst = raw_connection['i'].to_numpy()-1 | ||
# Create directed graph | ||
g = dgl.graph((src, dst)) | ||
g.edata['timestamp'] = torch.from_numpy( | ||
raw_connection['ts'].to_numpy()) | ||
g.edata['label'] = torch.from_numpy(raw_connection['label'].to_numpy()) | ||
g.edata['feats'] = torch.from_numpy(raw_feature[1:, :]).float() | ||
dgl.save_graphs('./data/{}.bin'.format(dataset), [g]) | ||
else: | ||
print("Data is exist directly loaded.") | ||
gs, _ = dgl.load_graphs('./data/{}.bin'.format(dataset)) | ||
g = gs[0] | ||
return g | ||
|
||
def TemporalWikipediaDataset(): | ||
# Download the dataset | ||
return TemporalDataset('wikipedia') | ||
|
||
def TemporalRedditDataset(): | ||
return TemporalDataset('reddit') |
Oops, something went wrong.