Commit d2b160f: readme updates

xwhan committed Jul 23, 2017 (1 parent: 85c1088)
Showing 3 changed files with 47 additions and 54 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
-# Reinforcement Learning for Knowledge Graph Reasoning
+# Deep Reinforcement Learning for Knowledge Graph Reasoning
We study the problem of learning to reason in large-scale knowledge graphs (KGs). More specifically, we describe a novel reinforcement learning framework for learning multi-hop relational paths: we use a policy-based agent with continuous states based on knowledge graph embeddings, which reasons in a KG vector space by sampling the most promising relation to extend its path. In contrast to prior work, our approach includes a reward function that takes **accuracy**, **diversity**, and **efficiency** into consideration. Experimentally, we show that our proposed method outperforms a path-ranking based algorithm and knowledge graph embedding methods on Freebase and Never-Ending Language Learning datasets.
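As a concrete reading of that reward design, here is a minimal sketch: the +1/-1 success term, the 1/length efficiency term, and the cosine-similarity diversity penalty follow the description above, while the equal weighting and all names (`path_reward`, `relation_embs`, `found_path_embs`) are illustrative assumptions, not the repository's API.

```python
import numpy as np

def path_reward(success, relation_embs, found_path_embs):
    """Sketch of a reward combining accuracy, efficiency, and diversity."""
    # Accuracy: positive reward only if the walk reached the target entity.
    r_accuracy = 1.0 if success else -1.0
    # Efficiency: shorter relational paths score higher.
    r_efficiency = 1.0 / len(relation_embs)
    # Diversity: penalize cosine similarity to paths found earlier.
    r_diversity = 0.0
    if found_path_embs:
        p = np.sum(relation_embs, axis=0)  # path = sum of its relation vectors
        sims = [np.dot(p, q) / (np.linalg.norm(p) * np.linalg.norm(q) + 1e-12)
                for q in found_path_embs]
        r_diversity = -float(np.mean(sims))
    # Equal weighting here is an assumption; the weights are hyperparameters.
    return r_accuracy + r_efficiency + r_diversity
```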

## Access the dataset
@@ -22,7 +22,8 @@ Download the knowledge graph dataset [NELL-995](http://cs.ucsb.edu/~xwhan/datase
2. `kb_env_rl.txt`: we add inverse triples of all triples in `raw.kb`; this file is used as the KG for reasoning
3. `entity2vec.bern/relation2vec.bern`: TransE embeddings to represent our RL states; these can be trained using [TransX implementations by thunlp](https://github.com/thunlp/Fast-TransX) (a loading sketch follows this list)
4. `tasks/`: each task is a particular reasoning relation
-* `tasks/${relation}/*.vec`: trained TransD(H) Embeddings
+* `tasks/${relation}/*.vec`: trained TransH Embeddings
+* `tasks/${relation}/*.vec_D`: trained TransD Embeddings
* `tasks/${relation}/*.bern`: trained TransR Embeddings
* `tasks/${relation}/*.unif`: trained TransE Embeddings
* `tasks/${relation}/transX`: triples used to train the KB embeddings
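All of the embedding files above are plain whitespace-separated matrices, so they load directly with NumPy. A small sketch, assuming the archive is unpacked to `NELL-995/` and using the standard TransE score; the function name and paths are illustrative:

```python
import numpy as np

# One row per entity/relation; paths below assume the unpacked NELL-995 folder.
ent_vec = np.loadtxt('NELL-995/entity2vec.bern')    # (num_entities, dim)
rel_vec = np.loadtxt('NELL-995/relation2vec.bern')  # (num_relations, dim)

def transe_score(h_id, r_id, t_id):
    """Standard TransE plausibility: -||h + r - t||^2 (higher is better)."""
    h, r, t = ent_vec[h_id], rel_vec[r_id], ent_vec[t_id]
    return -np.sum(np.square(h + r - t))
```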
94 changes: 43 additions & 51 deletions scripts/fact_prediction_eval.py
@@ -136,7 +136,7 @@ def get_features():
ent_vec_R = np.loadtxt(dataPath_ + '/entity2vec.bern')
rel_vec_R = np.loadtxt(dataPath_ + '/relation2vec.bern')
M = np.loadtxt(dataPath_ + '/A.bern')
-M = M.reshape([-1,100,100])
+M = M.reshape([-1,50,50])
relation_vec_R = rel_vec_R[relation2id[rel],:]
M_vec = M[relation2id[rel],:,:]
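For context, this hunk loads the TransR parameters: `A.bern` flattens one 50x50 projection matrix per relation, and the corrected reshape matches the 50-dimensional embeddings. Scoring a triple with these variables would look roughly like the sketch below; the function is hypothetical, and the projection convention `M_vec.dot(h)` is an assumption consistent with the shapes above.

```python
import numpy as np

def transr_score(h, t, relation_vec_R, M_vec):
    """TransR sketch: project 50-d entity vectors h, t into relation space
    with the relation-specific matrix M_vec, then translate and compare."""
    h_r = M_vec.dot(h)
    t_r = M_vec.dot(t)
    return -np.sum(np.square(h_r + relation_vec_R - t_r))
```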

@@ -259,7 +259,6 @@ def get_features():
M = np.loadtxt(dataPath_ + '/A.vec')
M = M.reshape([rel_vec.shape[0],-1])


f = open(test_data_path)
test_data = f.readlines()
f.close()
@@ -275,63 +274,21 @@ def get_features():
label = 1 if line[-2] == '+' else 0
test_labels.append(label)


-aps = []
-query = test_pairs[0][0]
-y_true = []
-y_score = []

score_all = []

rel = relation.replace("_", ":")
d_r = np.expand_dims(rel_vec[relation2id[rel],:],1)
w_r = np.expand_dims(M[relation2id[rel],:],1)

for idx, sample in enumerate(test_pairs):
-    #print 'query node: ', sample[0], idx
-    if sample[0] == query:
-        h = np.expand_dims(ent_vec[entity2id[sample[0]],:],1)
-        t = np.expand_dims(ent_vec[entity2id[sample[1]],:],1)
-        h_ = h - np.matmul(w_r.transpose(), h)*w_r
-        t_ = t - np.matmul(w_r.transpose(), t)*w_r
-        score = -np.sum(np.square(h_ + d_r - t_))
-        score_all.append(score)
-        y_score.append(score)
-        y_true.append(test_labels[idx])
-    else:
-        query = sample[0]
-        count = zip(y_score, y_true)
-        count.sort(key = lambda x:x[0], reverse=True)
-        #print count
-        ranks = []
-        correct = 0
-        for idx_, item in enumerate(count):
-            if item[1] == 1:
-                correct += 1
-                ranks.append(correct/(1.0+idx_))
-        if len(ranks)==0:
-            ranks.append(0)
-        aps.append(np.mean(ranks))
-        y_true = []
-        y_score = []
-        h = np.expand_dims(ent_vec[entity2id[sample[0]],:],1)
-        t = np.expand_dims(ent_vec[entity2id[sample[1]],:],1)
-        h_ = h - np.matmul(w_r.transpose(), h)*w_r
-        t_ = t - np.matmul(w_r.transpose(), t)*w_r
-        score = -np.sum(np.square(h_ + d_r - t_))
-        score_all.append(score)
-        y_score.append(score)
-        y_true.append(test_labels[idx])
+    h = np.expand_dims(ent_vec[entity2id[sample[0]],:],1)
+    t = np.expand_dims(ent_vec[entity2id[sample[1]],:],1)
+    h_ = h - np.matmul(w_r.transpose(), h)*w_r
+    t_ = t - np.matmul(w_r.transpose(), t)*w_r
+    score = -np.sum(np.square(h_ + d_r - t_))
+    score_all.append(score)

score_label = zip(score_all, test_labels)
stats = sorted(score_label, key = lambda x:x[0], reverse=True)
@@ -343,5 +300,40 @@ def get_features():
correct += 1
ranks.append(correct/(1.0+idx))
ap4 = np.mean(ranks)
-print 'TransX: ', ap4
+print 'TransH: ', ap4
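The rewritten loop above is TransH scoring: each entity is projected onto the relation-specific hyperplane with normal `w_r` before translating by `d_r`, and every test pair is now scored in a single pass instead of per query. As a self-contained restatement (a sketch; argument shapes follow the loop above):

```python
import numpy as np

def transh_score(h, t, d_r, w_r):
    """TransH sketch: h, t are (dim, 1) entity columns; w_r is the (assumed
    unit-norm) hyperplane normal and d_r the translation vector."""
    h_perp = h - np.matmul(w_r.transpose(), h) * w_r  # project onto hyperplane
    t_perp = t - np.matmul(w_r.transpose(), t) * w_r
    return -np.sum(np.square(h_perp + d_r - t_perp))  # higher = more plausible
```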

+ent_vec_D = np.loadtxt(dataPath_ + '/entity2vec.vec_D')
+rel_vec_D = np.loadtxt(dataPath_ + '/relation2vec.vec_D')
+M_D = np.loadtxt(dataPath_ + '/A.vec_D')
+ent_num = ent_vec_D.shape[0]
+rel_num = rel_vec_D.shape[0]
+rel_tran = M_D[0:rel_num,:]
+ent_tran = M_D[rel_num:,:]
+dim = ent_vec_D.shape[1]
+
+rel_id = relation2id[rel]
+r = np.expand_dims(rel_vec_D[rel_id,:], 1)
+r_p = np.expand_dims(rel_tran[rel_id,:], 1)
+scores_all_D = []
+for idx, sample in enumerate(test_pairs):
+    h = np.expand_dims(ent_vec_D[entity2id[sample[0]],:], 1)
+    h_p = np.expand_dims(ent_tran[entity2id[sample[0]],:], 1)
+    t = np.expand_dims(ent_vec_D[entity2id[sample[1]],:], 1)
+    t_p = np.expand_dims(ent_tran[entity2id[sample[1]],:], 1)
+    M_rh = np.matmul(r_p, h_p.transpose()) + np.identity(dim)
+    M_rt = np.matmul(r_p, t_p.transpose()) + np.identity(dim)
+    score = -np.sum(np.square(M_rh.dot(h) + r - M_rt.dot(t)))
+    scores_all_D.append(score)
+
+score_label = zip(scores_all_D, test_labels)
+stats = sorted(score_label, key = lambda x:x[0], reverse=True)
+
+correct = 0
+ranks = []
+for idx, item in enumerate(stats):
+    if item[1] == 1:
+        correct += 1
+        ranks.append(correct/(1.0+idx))
+ap5 = np.mean(ranks)
+print 'TransD: ', ap5
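The new TransD block builds, for every triple, the dynamic mapping matrices M_rh = r_p * h_p^T + I and M_rt = r_p * t_p^T + I from the projection vectors stored in `A.vec_D`, then scores the projected translation. Restated compactly (a sketch; shapes mirror the loop above):

```python
import numpy as np

def transd_score(h, t, r, h_p, t_p, r_p, dim):
    """TransD sketch: *_p are (dim, 1) projection columns; each mapping
    matrix is a rank-one update of the identity."""
    M_rh = np.matmul(r_p, h_p.transpose()) + np.identity(dim)
    M_rt = np.matmul(r_p, t_p.transpose()) + np.identity(dim)
    return -np.sum(np.square(M_rh.dot(h) + r - M_rt.dot(t)))
```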

2 changes: 1 addition & 1 deletion scripts/transR_eval.py
@@ -29,7 +29,7 @@
ent_vec = np.loadtxt(dataPath_ + '/entity2vec.bern')
rel_vec = np.loadtxt(dataPath_ + '/relation2vec.bern')
M = np.loadtxt(dataPath_ + '/A.bern')
-M = M.reshape([-1,100,100])
+M = M.reshape([-1,50,50])

f = open(test_data_path)
test_data = f.readlines()
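Both reshape fixes in this commit follow from the embedding dimensionality: `A.bern` stores the flattened per-relation projection matrices, so the reshape target has to match the trained dimension (50 here, not 100). A quick sanity check one could add (a sketch; `dataPath_` is the script's own variable, and the square-matrix layout is assumed as above):

```python
import numpy as np

rel_vec = np.loadtxt(dataPath_ + '/relation2vec.bern')
M = np.loadtxt(dataPath_ + '/A.bern')
dim = rel_vec.shape[1]  # 50 for these NELL-995 embeddings
assert M.size == rel_vec.shape[0] * dim * dim, 'A.bern does not match dim'
M = M.reshape([-1, dim, dim])  # one (dim x dim) matrix per relation
```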
