-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,232 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,41 @@ | ||
# de-simple | ||
Code coming soon. | ||
## Diachronic Embedding for Temporal Knowledge Graph Completion | ||
This repository contains code for the representation proposed in the paper [Diachronic Embedding for Temporal Knowledge Graph Completion](https://arxiv.org/pdf/1907.03143.pdf). | ||
## Installation | ||
- Create a conda environment: | ||
``` | ||
$ conda create -n tkgc python=3.6 anaconda | ||
``` | ||
- Run | ||
``` | ||
$ source activate tkgc | ||
``` | ||
- Change directory to TKGC folder | ||
- Run | ||
``` | ||
$ pip install -r requirements.txt | ||
``` | ||
## How to use? | ||
After installing the requirements, run the following command to reproduce results for DE-SimplE: | ||
``` | ||
$ python main.py -dropout 0.4 -se_prop 0.68 -model DE-SimplE | ||
``` | ||
To reproduce the results for DE-DistMult and DE-TransE, specify **model** as DE-DistMult or DE-TransE as follows: | ||
``` | ||
$ python main.py -dropout 0.4 -se_prop 0.36 -model DE-DistMult | ||
$ python main.py -dropout 0.4 -se_prop 0.36 -model DE-TransE | ||
``` | ||
## Citation | ||
If you use this code, please cite the following paper: | ||
``` | ||
@article{goel2019diachronic, | ||
title={Diachronic Embedding for Temporal Knowledge Graph Completion}, | ||
author={Goel, Rishab and Kazemi, Seyed Mehran and Brubaker, Marcus and Poupart, Pascal}, | ||
journal={arXiv preprint arXiv:1907.03143}, | ||
year={2019} | ||
} | ||
``` | ||
## License | ||
Copyright (c) 2018-present, Royal Bank of Canada. | ||
All rights reserved. | ||
This source code is licensed under the license found in the | ||
LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) 2018-present, Royal Bank of Canada. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) 2018-present, Royal Bank of Canada. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
import random | ||
import math | ||
import copy | ||
import time | ||
import numpy as np | ||
from random import shuffle | ||
from scripts import shredFacts | ||
|
||
class Dataset:
    """Loads a temporal knowledge-graph dataset and serves training batches.

    The three splits ("train", "valid", "test") are read from
    ``datasets/<lower-cased dataset name>/{train,valid,test}.txt``. Entity and
    relation names are mapped to consecutive integer ids in order of first
    appearance, and the textual timestamp of every fact is expanded into
    numeric columns.
    """

    def __init__(self, ds_name):
        """
        Params:
            ds_name : name of the dataset
        """
        self.name = ds_name
        # Dataset files are looked up relative to the working directory.
        self.ds_path = "datasets/" + ds_name.lower() + "/"
        self.ent2id = {}
        self.rel2id = {}
        # Splits are read in this order; ids are handed out on first sight, so
        # the order determines which id every entity/relation receives.
        self.data = {"train": self.readFile(self.ds_path + "train.txt"),
                     "valid": self.readFile(self.ds_path + "valid.txt"),
                     "test": self.readFile(self.ds_path + "test.txt")}

        self.start_batch = 0
        self.all_facts_as_tuples = None

        self.convertTimes()

        # Every known fact (from all splits) as a hashable tuple.
        self.all_facts_as_tuples = {
            tuple(d)
            for d in self.data["train"] + self.data["valid"] + self.data["test"]
        }

        for spl in ["train", "valid", "test"]:
            self.data[spl] = np.array(self.data[spl])

    def readFile(self, filename):
        """Read one split file and return its facts as a list of lists.

        Each non-empty line must hold at least four tab-separated fields:
        head entity, relation, tail entity and timestamp. Names are converted
        to integer ids; the timestamp is kept as a string here and is split
        later by ``convertTimes``. Blank lines are skipped.

        Params:
            filename : path of the file to read
        Returns:
            list of ``[head_id, rel_id, tail_id, timestamp]`` entries
        """
        facts = []
        with open(filename, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines instead of
                    # failing with an IndexError below.
                    continue
                elements = line.split("\t")

                head_id = self.getEntID(elements[0])
                rel_id = self.getRelID(elements[1])
                tail_id = self.getEntID(elements[2])
                timestamp = elements[3]

                facts.append([head_id, rel_id, tail_id, timestamp])

        return facts

    def convertTimes(self):
        """
        Replace the trailing timestamp string of every fact by its numeric
        parts: the string is split on "-" and each part is converted to
        float, e.g. [h, r, t, "2014-01-02"] -> [h, r, t, 2014.0, 1.0, 2.0].
        NOTE(review): presumably the parts are year, month, day - confirm
        against the dataset files.
        """
        for split in ["train", "valid", "test"]:
            for i, fact in enumerate(self.data[split]):
                date = [float(part) for part in fact[-1].split("-")]
                self.data[split][i] = fact[:-1] + date

    def numEnt(self):
        """Return the number of distinct entities seen so far."""
        return len(self.ent2id)

    def numRel(self):
        """Return the number of distinct relations seen so far."""
        return len(self.rel2id)

    def getEntID(self, ent_name):
        """Return the id of ``ent_name``, assigning the next free id if new."""
        if ent_name not in self.ent2id:
            self.ent2id[ent_name] = len(self.ent2id)
        return self.ent2id[ent_name]

    def getRelID(self, rel_name):
        """Return the id of ``rel_name``, assigning the next free id if new."""
        if rel_name not in self.rel2id:
            self.rel2id[rel_name] = len(self.rel2id)
        return self.rel2id[rel_name]

    def nextPosBatch(self, batch_size):
        """Return the next slice of (at most) ``batch_size`` training facts.

        When the slice reaches the end of the training data the cursor is
        reset to 0, which is what ``wasLastBatch`` reports.
        """
        # ">=" (not ">"): a batch that ends exactly on the last training fact
        # is the last batch of the epoch. With ">" the cursor would be left at
        # len(train) and the following call would return an empty batch.
        if self.start_batch + batch_size >= len(self.data["train"]):
            ret_facts = self.data["train"][self.start_batch:]
            self.start_batch = 0
        else:
            ret_facts = self.data["train"][self.start_batch: self.start_batch + batch_size]
            self.start_batch += batch_size
        return ret_facts

    def addNegFacts(self, bp_facts, neg_ratio):
        """Interleave every positive fact with corrupted copies of itself.

        Each positive fact is expanded into ``2 * neg_ratio + 2`` consecutive
        rows: the fact, ``neg_ratio`` copies with a replaced head (column 0),
        the fact again, and ``neg_ratio`` copies with a replaced tail
        (column 2).

        Params:
            bp_facts : 2-D numpy array of positive facts
            neg_ratio : number of corrupted heads (and tails) per positive
        Returns:
            numpy array with ``bp_facts.shape[0] * (2 * neg_ratio + 2)`` rows
        """
        ex_per_pos = 2 * neg_ratio + 2
        facts = np.repeat(np.copy(bp_facts), ex_per_pos, axis=0)
        num_ent = self.numEnt()
        for i in range(bp_facts.shape[0]):
            s1 = i * ex_per_pos + 1
            e1 = s1 + neg_ratio
            s2 = e1 + 1
            e2 = s2 + neg_ratio

            # Adding an offset in [1, num_ent - 1] modulo num_ent always
            # yields an entity different from the original one.
            facts[s1:e1, 0] = (facts[s1:e1, 0] + np.random.randint(low=1, high=num_ent, size=neg_ratio)) % num_ent
            facts[s2:e2, 2] = (facts[s2:e2, 2] + np.random.randint(low=1, high=num_ent, size=neg_ratio)) % num_ent

        return facts

    def addNegFacts2(self, bp_facts, neg_ratio):
        """Build head-corrupted and tail-corrupted groups for every positive.

        Every positive fact is repeated ``1 + neg_ratio`` times; the first row
        of each group stays untouched and the remaining rows get a random
        different entity. This is done once with the head (column 0) replaced
        and once with the tail (column 2) replaced, and the two arrays are
        stacked.

        Params:
            bp_facts : 2-D numpy array of positive facts
            neg_ratio : number of negatives per positive and per side
        Returns:
            numpy array with ``2 * bp_facts.shape[0] * (1 + neg_ratio)`` rows
        """
        pos_neg_group_size = 1 + neg_ratio
        num_ent = self.numEnt()
        facts1 = np.repeat(np.copy(bp_facts), pos_neg_group_size, axis=0)
        facts2 = np.copy(facts1)
        rand_nums1 = np.random.randint(low=1, high=num_ent, size=facts1.shape[0])
        rand_nums2 = np.random.randint(low=1, high=num_ent, size=facts2.shape[0])

        # A zero offset keeps the first row of every group as the positive.
        rand_nums1[::pos_neg_group_size] = 0
        rand_nums2[::pos_neg_group_size] = 0

        facts1[:, 0] = (facts1[:, 0] + rand_nums1) % num_ent
        facts2[:, 2] = (facts2[:, 2] + rand_nums2) % num_ent
        return np.concatenate((facts1, facts2), axis=0)

    def nextBatch(self, batch_size, neg_ratio=1):
        """Return the next training batch with negatives added.

        The positives come from ``nextPosBatch``, negatives from
        ``addNegFacts2``; the result is passed through ``shredFacts``
        (imported from ``scripts``; its output format is not visible here).
        """
        bp_facts = self.nextPosBatch(batch_size)
        batch = shredFacts(self.addNegFacts2(bp_facts, neg_ratio))
        return batch

    def wasLastBatch(self):
        """Return True when the batch cursor is at the start of the data,
        i.e. right after the final batch of an epoch was handed out (or
        before any batch was requested)."""
        return self.start_batch == 0
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2018-present, Royal Bank of Canada. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
import torch.nn.functional as F | ||
from params import Params | ||
from dataset import Dataset | ||
|
||
class DE_DistMult(torch.nn.Module):
    '''
    DE_DistMult model from https://arxiv.org/abs/1907.03143

    Every entity has a static embedding of size params.s_emb_dim and a
    time-dependent embedding of size params.t_emb_dim; the two are
    concatenated and scored against a relation embedding with a
    DistMult-style product.
    '''

    def __init__(self, dataset, params):
        '''
        Params:
            dataset : object providing numEnt() and numRel()
            params : object providing s_emb_dim, t_emb_dim and dropout
        '''
        super().__init__()
        self.dataset = dataset
        self.params = params

        # Non-linearity applied inside the temporal part of the embedding.
        self.time_nl = torch.sin

        # Static embeddings; the relation embedding spans the static and the
        # temporal part of the entity embedding.
        static_dim = params.s_emb_dim
        full_dim = params.s_emb_dim + params.t_emb_dim
        self.ent_embs = nn.Embedding(dataset.numEnt(), static_dim).cuda()
        self.rel_embs = nn.Embedding(dataset.numRel(), full_dim).cuda()

        # Temporal embeddings are created and initialized before the static
        # ones are re-initialized; the order is kept so seeded runs match.
        self.create_time_embedds()

        for table in (self.ent_embs, self.rel_embs):
            nn.init.xavier_uniform_(table.weight)

    def create_time_embedds(self):
        '''
        Create the per-entity frequency, phase (phi) and amplitude embeddings
        for month (m), day (d) and year (y), each Xavier-initialized.
        '''
        n_ent = self.dataset.numEnt()
        t_dim = self.params.t_emb_dim

        # Groups are built in the order freq, phi, amp and, inside a group,
        # m, d, y: first all three tables are created, then all three are
        # initialized.
        for kind in ("freq", "phi", "amp"):
            attr_names = [unit + "_" + kind for unit in ("m", "d", "y")]
            for attr in attr_names:
                setattr(self, attr, nn.Embedding(n_ent, t_dim).cuda())
            for attr in attr_names:
                nn.init.xavier_uniform_(getattr(self, attr).weight)

    def get_time_embedd(self, entities, year, month, day):
        '''
        Return the temporal embedding of `entities`: the sum over year, month
        and day of amp * time_nl(freq * value + phi).
        '''
        year_part = self.y_amp(entities) * self.time_nl(
            self.y_freq(entities) * year + self.y_phi(entities))
        month_part = self.m_amp(entities) * self.time_nl(
            self.m_freq(entities) * month + self.m_phi(entities))
        day_part = self.d_amp(entities) * self.time_nl(
            self.d_freq(entities) * day + self.d_phi(entities))

        return year_part + month_part + day_part

    def getEmbeddings(self, heads, rels, tails, years, months, days, intervals = None):
        '''
        Look up head, relation and tail embeddings; head and tail embeddings
        are the static part concatenated with the temporal part along dim 1.
        `intervals` is accepted but not used.
        '''
        # Column vectors so the date values broadcast over the embedding dim.
        years = years.view(-1, 1)
        months = months.view(-1, 1)
        days = days.view(-1, 1)

        head_static = self.ent_embs(heads)
        rel_emb = self.rel_embs(rels)
        tail_static = self.ent_embs(tails)

        head_temporal = self.get_time_embedd(heads, years, months, days)
        tail_temporal = self.get_time_embedd(tails, years, months, days)

        head_emb = torch.cat((head_static, head_temporal), 1)
        tail_emb = torch.cat((tail_static, tail_temporal), 1)
        return head_emb, rel_emb, tail_emb

    def forward(self, heads, rels, tails, years, months, days):
        '''
        Return one score per fact: the element-wise product of head, relation
        and tail embeddings, with dropout applied, summed over dim 1.
        '''
        h_embs, r_embs, t_embs = self.getEmbeddings(heads, rels, tails, years, months, days)

        product = (h_embs * r_embs) * t_embs
        product = F.dropout(product, p=self.params.dropout, training=self.training)
        return torch.sum(product, dim=1)
|
Oops, something went wrong.