forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* PPIDataset * Revert "PPIDataset" This reverts commit 264bd0c. * gdelt dataset * Update gdelt.py Co-authored-by: xiang song(charlie.song) <[email protected]>
- Loading branch information
1 parent
06ea03d
commit 73b9c6f
Showing
2 changed files
with
140 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,167 @@ | ||
from scipy import io | ||
""" GDELT dataset for temporal graph """ | ||
import numpy as np | ||
import os | ||
import datetime | ||
|
||
from .utils import get_download_dir, download, extract_archive, loadtxt | ||
from ..utils import retry_method_with_fix | ||
from .. import convert | ||
from .dgl_dataset import DGLBuiltinDataset | ||
from .utils import loadtxt, save_info, load_info, _get_dgl_url | ||
from ..convert import graph as dgl_graph | ||
from .. import backend as F | ||
|
||
|
||
class GDELT(object): | ||
""" | ||
The Global Database of Events, Language, and Tone (GDELT) dataset. | ||
This contains events happend all over the world (ie every protest held anywhere | ||
in Russia on a given day is collapsed to a single entry). | ||
class GDELTDataset(DGLBuiltinDataset): | ||
r"""GDELT dataset for event-based temporal graph | ||
This Dataset consists of | ||
events collected from 1/1/2018 to 1/31/2018 (15 minutes time granularity). | ||
The Global Database of Events, Language, and Tone (GDELT) dataset. | ||
This contains events happend all over the world (ie every protest held | ||
anywhere in Russia on a given day is collapsed to a single entry). | ||
This Dataset consists ofevents collected from 1/1/2018 to 1/31/2018 | ||
(15 minutes time granularity). | ||
Reference: | ||
- `Recurrent Event Network for Reasoning over Temporal | ||
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_ | ||
- `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_ | ||
- `Recurrent Event Network for Reasoning over Temporal | ||
Knowledge Graphs` <https://arxiv.org/abs/1904.05530> | ||
- `The Global Database of Events, Language, and Tone (GDELT) ` | ||
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075> | ||
Statistics | ||
---------- | ||
Train examples: 2,304 | ||
Valid examples: 288 | ||
Test examples: 384 | ||
Parameters | ||
------------ | ||
mode: str | ||
Load train/valid/test data. Has to be one of ['train', 'valid', 'test'] | ||
---------- | ||
mode : str | ||
Must be one of ('train', 'valid', 'test'). Default: 'train' | ||
raw_dir : str | ||
Raw file directory to download/contains the input data directory. | ||
Default: ~/.dgl/ | ||
force_reload : bool | ||
Whether to reload the dataset. Default: False | ||
verbose: bool | ||
Whether to print out progress information. Default: True. | ||
Attributes | ||
---------- | ||
start_time : int | ||
Start time of the temporal graph | ||
end_time : int | ||
End time of the temporal graph | ||
is_temporal : bool | ||
Does the dataset contain temporal graphs | ||
Examples | ||
---------- | ||
>>> # get train, valid, test dataset | ||
>>> train_data = GDELTDataset() | ||
>>> valid_data = GDELTDataset(mode='valid') | ||
>>> test_data = GDELTDataset(mode='test') | ||
>>> | ||
>>> # length of train set | ||
>>> train_size = len(train_data) | ||
>>> | ||
>>> for g in train_data: | ||
.... e_feat = g.edata['rel_type'] | ||
.... # your code here | ||
.... | ||
>>> | ||
""" | ||
_url = { | ||
'train': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/train.txt', | ||
'valid': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/valid.txt', | ||
'test': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/test.txt', | ||
} | ||
|
||
def __init__(self, mode): | ||
assert mode.lower() in self._url, "Mode not valid" | ||
self.dir = get_download_dir() | ||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False): | ||
mode = mode.lower() | ||
assert mode in ['train', 'valid', 'test'], "Mode not valid." | ||
self.mode = mode | ||
# self.graphs = [] | ||
train_data = loadtxt(os.path.join( | ||
self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64) | ||
if self.mode == 'train': | ||
self._load(train_data) | ||
elif self.mode == 'valid': | ||
val_data = loadtxt(os.path.join( | ||
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64) | ||
train_data[:, 3] = -1 | ||
self._load(np.concatenate([train_data, val_data], axis=0)) | ||
elif self.mode == 'test': | ||
val_data = loadtxt(os.path.join( | ||
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64) | ||
test_data = loadtxt(os.path.join( | ||
self.dir, 'GDELT', 'test.txt'), delimiter='\t').astype(np.int64) | ||
train_data[:, 3] = -1 | ||
val_data[:, 3] = -1 | ||
self._load(np.concatenate( | ||
[train_data, val_data, test_data], axis=0)) | ||
|
||
def _download(self): | ||
for dname in self._url: | ||
dpath = os.path.join( | ||
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1]) | ||
download(self._url[dname.lower()], path=dpath) | ||
|
||
@retry_method_with_fix(_download) | ||
def _load(self, data): | ||
self.num_nodes = 23033 | ||
_url = _get_dgl_url('dataset/gdelt.zip') | ||
super(GDELTDataset, self).__init__(name='GDELT', | ||
url=_url, | ||
raw_dir=raw_dir, | ||
force_reload=force_reload, | ||
verbose=verbose) | ||
|
||
def process(self): | ||
file_path = os.path.join(self.raw_path, self.mode + '.txt') | ||
self.data = loadtxt(file_path, delimiter='\t').astype(np.int64) | ||
|
||
# The source code is not released, but the paper indicates there're | ||
# totally 137 samples. The cutoff below has exactly 137 samples. | ||
self.data = data | ||
self.time_index = np.floor(data[:, 3]/15).astype(np.int64) | ||
self.start_time = self.time_index[self.time_index != -1].min() | ||
self.end_time = self.time_index.max() | ||
self.time_index = np.floor(self.data[:, 3] / 15).astype(np.int64) | ||
self._start_time = self.time_index.min() | ||
self._end_time = self.time_index.max() | ||
|
||
def has_cache(self): | ||
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') | ||
return os.path.exists(info_path) | ||
|
||
def save(self): | ||
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') | ||
save_info(info_path, {'data': self.data, | ||
'time_index': self.time_index, | ||
'start_time': self.start_time, | ||
'end_time': self.end_time}) | ||
|
||
def load(self): | ||
info_path = os.path.join(self.save_path, self.mode + '_info.pkl') | ||
info = load_info(info_path) | ||
self.data, self.time_index, self._start_time, self._end_time = \ | ||
info['data'], info['time_index'], info['start_time'], info['end_time'] | ||
|
||
@property | ||
def start_time(self): | ||
r""" Start time of events in the temporal graph | ||
def __getitem__(self, idx): | ||
if idx >= len(self) or idx < 0: | ||
Returns | ||
------- | ||
int | ||
""" | ||
return self._start_time | ||
|
||
@property | ||
def end_time(self): | ||
r""" End time of events in the temporal graph | ||
Returns | ||
------- | ||
int | ||
""" | ||
return self._end_time | ||
|
||
def __getitem__(self, t): | ||
r""" Get graph by with events before time `t + self.start_time` | ||
Parameters | ||
---------- | ||
t : int | ||
Time, its value must be in range [0, `self.end_time` - `self.start_time`] | ||
Returns | ||
------- | ||
dgl.DGLGraph | ||
graph structure and edge feature | ||
- edata['rel_type']: edge type | ||
""" | ||
if t >= len(self) or t < 0: | ||
raise IndexError("Index out of range") | ||
i = idx + self.start_time | ||
i = t + self.start_time | ||
row_mask = self.time_index <= i | ||
edges = self.data[row_mask][:, [0, 2]] | ||
rate = self.data[row_mask][:, 1] | ||
g = convert.graph((edges[:, 0], edges[:, 1])) | ||
g.edata['rel_type'] = rate.reshape(-1, 1) | ||
g = dgl_graph((edges[:, 0], edges[:, 1])) | ||
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64']) | ||
return g | ||
|
||
def __len__(self): | ||
return self.end_time - self.start_time + 1 | ||
|
||
@property | ||
def num_nodes(self): | ||
return 23033 | ||
r"""Number of graphs in the dataset""" | ||
return self._end_time - self._start_time + 1 | ||
|
||
@property | ||
def is_temporal(self): | ||
r""" Does the dataset contain temporal graphs | ||
Returns | ||
------- | ||
bool | ||
""" | ||
return True | ||
|
||
|
||
GDELT = GDELTDataset |