Skip to content

Commit

Permalink
[Dataset] GDELTDataset (dmlc#1911)
Browse files Browse the repository at this point in the history
* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c.

* gdelt dataset

* Update gdelt.py

Co-authored-by: xiang song(charlie.song) <[email protected]>
  • Loading branch information
HuXiangkun and classicsong authored Aug 3, 2020
1 parent 06ea03d commit 73b9c6f
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 70 deletions.
2 changes: 1 addition & 1 deletion python/dgl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT
from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
Expand Down
208 changes: 139 additions & 69 deletions python/dgl/data/gdelt.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,167 @@
from scipy import io
""" GDELT dataset for temporal graph """
import numpy as np
import os
import datetime

from .utils import get_download_dir, download, extract_archive, loadtxt
from ..utils import retry_method_with_fix
from .. import convert
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_info, load_info, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F


class GDELTDataset(DGLBuiltinDataset):
    r"""GDELT dataset for event-based temporal graphs.

    The Global Database of Events, Language, and Tone (GDELT) dataset.
    This contains events that happened all over the world (i.e. every
    protest held anywhere in Russia on a given day is collapsed to a
    single entry). This dataset consists of events collected from
    1/1/2018 to 1/31/2018 (15 minute time granularity).

    Reference:

        - `Recurrent Event Network for Reasoning over Temporal
          Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
        - `The Global Database of Events, Language, and Tone (GDELT)
          <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_

    Statistics
    ----------
    Train examples: 2,304
    Valid examples: 288
    Test examples: 384

    Parameters
    ----------
    mode : str
        Must be one of ('train', 'valid', 'test'). Default: 'train'
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False

    Attributes
    ----------
    start_time : int
        Start time of the temporal graph
    end_time : int
        End time of the temporal graph
    is_temporal : bool
        Does the dataset contain temporal graphs

    Examples
    --------
    >>> # get train, valid, test dataset
    >>> train_data = GDELTDataset()
    >>> valid_data = GDELTDataset(mode='valid')
    >>> test_data = GDELTDataset(mode='test')
    >>>
    >>> # length of train set
    >>> train_size = len(train_data)
    >>>
    >>> for g in train_data:
    ...     e_feat = g.edata['rel_type']
    ...     # your code here
    """

    def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
        mode = mode.lower()
        assert mode in ['train', 'valid', 'test'], "Mode not valid."
        self.mode = mode
        _url = _get_dgl_url('dataset/gdelt.zip')
        super(GDELTDataset, self).__init__(name='GDELT',
                                           url=_url,
                                           raw_dir=raw_dir,
                                           force_reload=force_reload,
                                           verbose=verbose)

    def process(self):
        # Parse the tab-separated event list for the requested split.
        file_path = os.path.join(self.raw_path, self.mode + '.txt')
        self.data = loadtxt(file_path, delimiter='\t').astype(np.int64)

        # Column 3 is a timestamp in minutes; events are binned into
        # 15-minute snapshots indexed by self.time_index.
        self.time_index = np.floor(self.data[:, 3] / 15).astype(np.int64)
        self._start_time = self.time_index.min()
        self._end_time = self.time_index.max()

    def has_cache(self):
        # A split is cached iff its pickled info file exists.
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        return os.path.exists(info_path)

    def save(self):
        # Persist the parsed events and time range for fast reload.
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        save_info(info_path, {'data': self.data,
                              'time_index': self.time_index,
                              'start_time': self.start_time,
                              'end_time': self.end_time})

    def load(self):
        # Restore exactly what save() wrote.
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        info = load_info(info_path)
        self.data, self.time_index, self._start_time, self._end_time = \
            info['data'], info['time_index'], info['start_time'], info['end_time']

    @property
    def start_time(self):
        r"""Start time of events in the temporal graph

        Returns
        -------
        int
        """
        return self._start_time

    @property
    def end_time(self):
        r"""End time of events in the temporal graph

        Returns
        -------
        int
        """
        return self._end_time

    def __getitem__(self, t):
        r"""Get graph with events before time `t + self.start_time`

        Parameters
        ----------
        t : int
            Time, its value must be in range
            [0, `self.end_time` - `self.start_time`]

        Returns
        -------
        dgl.DGLGraph
            graph structure and edge feature

            - edata['rel_type']: edge type

        Raises
        ------
        IndexError
            If `t` is outside [0, len(self) - 1].
        """
        if t >= len(self) or t < 0:
            raise IndexError("Index out of range")
        i = t + self.start_time
        # Cumulative snapshot: include every event up to time bin i.
        row_mask = self.time_index <= i
        edges = self.data[row_mask][:, [0, 2]]
        rate = self.data[row_mask][:, 1]
        g = dgl_graph((edges[:, 0], edges[:, 1]))
        g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1),
                                       dtype=F.data_type_dict['int64'])
        return g

    def __len__(self):
        r"""Number of graphs in the dataset"""
        return self._end_time - self._start_time + 1

    @property
    def is_temporal(self):
        r"""Does the dataset contain temporal graphs

        Returns
        -------
        bool
        """
        return True


# Backward-compatible alias for the old class name.
GDELT = GDELTDataset

0 comments on commit 73b9c6f

Please sign in to comment.