Port to Python3
fridex committed Jan 2, 2020
1 parent e9981de commit c3a571d
Showing 5 changed files with 56 additions and 50 deletions.
scripts/BFS/BFS.py: 7 changes (5 additions, 2 deletions)
@@ -1,4 +1,7 @@
from Queue import Queue
try:
from Queue import Queue
except ImportError:
from queue import Queue
import random

def BFS(kb, entity1, entity2):
@@ -53,4 +56,4 @@ def __str__(self):
res = ""
for entity, status in self.entities.iteritems():
res += entity + "[{},{},{}]".format(status[0],status[1],status[2])
return res
return res
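The try/except import above is the usual way to cover the standard-library rename (Queue in Python 2, queue in Python 3). Note that the unchanged __str__ context in the second hunk still calls dict.iteritems(), which was removed in Python 3. A minimal sketch of the same loop using dict.items(), with a hypothetical entities dictionary standing in for self.entities:

# dict.items() works on both Python 2 and 3; iteritems() is Python 2 only
entities = {"concept_city_boston": (1, 0, 0)}   # hypothetical contents of self.entities
res = ""
for entity, status in entities.items():
    res += entity + "[{},{},{}]".format(status[0], status[1], status[2])
print(res)   # concept_city_boston[1,0,0]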
scripts/env.py: 6 changes (3 additions, 3 deletions)
@@ -69,15 +69,15 @@ def interact(self, state, action):
path = random.choice(choices)
self.path.append(path[2] + ' -> ' + path[1])
self.path_relations.append(path[2])
# print 'Find a valid step', path
# print 'Action index', action
# print('Find a valid step', path)
# print('Action index', action)
self.die = 0
new_pos = self.entity2id_[path[1]]
reward = 0
new_state = [new_pos, target_pos, self.die]

if new_pos == target_pos:
print 'Find a path:', self.path
print('Find a path:', self.path)
done = 1
reward = 0
new_state = None
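The hunk above rewrites the print statements as print() calls but, unlike the other files in this commit, does not show a from __future__ import print_function being added. Assuming env.py does not import it elsewhere, Python 2 would parse the new form as a print statement applied to a tuple rather than as a function call. A quick illustration:

# Under plain Python 2 this is a print *statement* applied to a tuple:
#   ('Find a path:', ['a -> b'])
# With from __future__ import print_function, or on any Python 3, it is a call:
#   Find a path: ['a -> b']
print('Find a path:', ['a -> b'])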
scripts/policy_agent.py: 51 changes (26 additions, 25 deletions)
@@ -1,4 +1,5 @@
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import collections
@@ -56,8 +57,8 @@ def REINFORCE(training_pairs, policy_nn, num_episodes):

for i_episode in range(num_episodes):
start = time.time()
print 'Episode %d' % i_episode
print 'Training sample: ', train[i_episode][:-1]
print('Episode %d' % i_episode)
print('Training sample: ', train[i_episode][:-1])

env = Env(dataPath, train[i_episode])

@@ -87,17 +88,17 @@ def REINFORCE(training_pairs, policy_nn, num_episodes):

# Discourage the agent when it choose an invalid step
if len(state_batch_negative) != 0:
print 'Penalty to invalid steps:', len(state_batch_negative)
print('Penalty to invalid steps:', len(state_batch_negative))
policy_nn.update(np.reshape(state_batch_negative, (-1, state_dim)), -0.05, action_batch_negative)

print '----- FINAL PATH -----'
print '\t'.join(env.path)
print 'PATH LENGTH', len(env.path)
print '----- FINAL PATH -----'
print('----- FINAL PATH -----')
print('\t'.join(env.path))
print('PATH LENGTH', len(env.path))
print('----- FINAL PATH -----')

# If the agent success, do one optimization
if done == 1:
print 'Success'
print('Success')

path_found_entity.append(path_clean(' -> '.join(env.path)))

@@ -139,7 +140,7 @@ def REINFORCE(training_pairs, policy_nn, num_episodes):
action_batch.append(transition.action)
policy_nn.update(np.reshape(state_batch, (-1,state_dim)), total_reward, action_batch)

print 'Failed, Do one teacher guideline'
print('Failed, Do one teacher guideline')
try:
good_episodes = teacher(sample[0], sample[1], 1, env, graphpath)
for item in good_episodes:
@@ -152,10 +153,10 @@ def REINFORCE(training_pairs, policy_nn, num_episodes):
policy_nn.update(np.squeeze(teacher_state_batch), 1, teacher_action_batch)

except Exception as e:
print 'Teacher guideline failed'
print 'Episode time: ', time.time() - start
print '\n'
print 'Success percentage:', success/num_episodes
print('Teacher guideline failed')
print('Episode time: ', time.time() - start)
print('\n')
print('Success percentage:', success/num_episodes)

for path in path_found_entity:
rel_ent = path.split(' -> ')
@@ -172,12 +173,12 @@ def REINFORCE(training_pairs, policy_nn, num_episodes):
for item in relation_path_stats:
f.write(item[0]+'\t'+str(item[1])+'\n')
f.close()
print 'Path stats saved'
print('Path stats saved')

return

def retrain():
print 'Start retraining'
print('Start retraining')
tf.reset_default_graph()
policy_network = PolicyNetwork(scope = 'supervised_policy')

@@ -188,13 +189,13 @@ def retrain():
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'models/policy_supervised_' + relation)
print "sl_policy restored"
print("sl_policy restored")
episodes = len(training_pairs)
if episodes > 300:
episodes = 300
REINFORCE(training_pairs, policy_network, episodes)
saver.save(sess, 'models/policy_retrained' + relation)
print 'Retrained model saved'
print('Retrained model saved')

def test():
tf.reset_default_graph()
@@ -216,13 +217,13 @@ def test():

with tf.Session() as sess:
saver.restore(sess, 'models/policy_retrained' + relation)
print 'Model reloaded'
print('Model reloaded')

if test_num > 500:
test_num = 500

for episode in xrange(test_num):
print 'Test sample %d: %s' % (episode,test_data[episode][:-1])
for episode in range(test_num):
print('Test sample %d: %s' % (episode,test_data[episode][:-1]))
env = Env(dataPath, test_data[episode])
sample = test_data[episode].split()
state_idx = [env.entity2id_[sample[0]], env.entity2id_[sample[1]], 0]
@@ -243,11 +244,11 @@ def test():
if done or t == max_steps_test:
if done:
success += 1
print "Success\n"
print("Success\n")
path = path_clean(' -> '.join(env.path))
path_found.append(path)
else:
print 'Episode ends due to step limit\n'
print('Episode ends due to step limit\n')
break
state_idx = new_state

@@ -258,7 +259,7 @@ def test():
path_found_embedding = np.reshape(path_found_embedding, (-1,embedding_dim))
cos_sim = cosine_similarity(path_found_embedding, curr_path_embedding)
diverse_reward = -np.mean(cos_sim)
print 'diverse_reward', diverse_reward
print('diverse_reward', diverse_reward)
#total_reward = 0.1*global_reward + 0.8*length_reward + 0.1*diverse_reward
state_batch = []
action_batch = []
@@ -289,13 +290,13 @@ def test():
ranking_path.append((path, length))

ranking_path = sorted(ranking_path, key = lambda x:x[1])
print 'Success persentage:', success/test_num
print('Success persentage:', success/test_num)

f = open(dataPath + 'tasks/' + relation + '/' + 'path_to_use.txt', 'w')
for item in ranking_path:
f.write(item[0] + '\n')
f.close()
print 'path to use saved'
print('path to use saved')
return
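The xrange loops in test() become range, which exists under both interpreters; the practical difference is that Python 2's range builds a list while Python 3's is lazy, which is harmless for these episode counts. If laziness mattered under Python 2, a hypothetical shim (not part of this commit) could keep the new spelling:

try:
    range = xrange  # Python 2: reuse the lazy iterator under the Python 3 name
except NameError:
    pass            # Python 3: the built-in range is already lazy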

if __name__ == "__main__":
scripts/sl_policy.py: 25 changes (13 additions, 12 deletions)
@@ -1,4 +1,5 @@
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from itertools import count
@@ -59,17 +60,17 @@ def train():
else:
num_episodes = num_samples

for episode in xrange(num_samples):
print "Episode %d" % episode
print 'Training Sample:', train_data[episode%num_samples][:-1]
for episode in range(num_samples):
print("Episode %d" % episode)
print('Training Sample:', train_data[episode%num_samples][:-1])

env = Env(dataPath, train_data[episode%num_samples])
sample = train_data[episode%num_samples].split()

try:
good_episodes = teacher(sample[0], sample[1], 5, env, graphpath)
except Exception as e:
print 'Cannot find a path'
print('Cannot find a path')
continue

for item in good_episodes:
@@ -83,7 +84,7 @@ def train():
policy_nn.update(state_batch, action_batch)

saver.save(sess, 'models/policy_supervised_' + relation)
print 'Model saved'
print('Model saved')


def test(test_episodes):
@@ -97,16 +98,16 @@ def test(test_episodes):
test_num = len(test_data)

test_data = test_data[-test_episodes:]
print len(test_data)
print(len(test_data))

success = 0

saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'models/policy_supervised_'+ relation)
print 'Model reloaded'
for episode in xrange(len(test_data)):
print 'Test sample %d: %s' % (episode,test_data[episode][:-1])
print('Model reloaded')
for episode in range(len(test_data)):
print('Test sample %d: %s' % (episode,test_data[episode][:-1]))
env = Env(dataPath, test_data[episode])
sample = test_data[episode].split()
state_idx = [env.entity2id_[sample[0]], env.entity2id_[sample[1]], 0]
@@ -117,13 +118,13 @@ def test(test_episodes):
reward, new_state, done = env.interact(state_idx, action_chosen)
if done or t == max_steps_test:
if done:
print 'Success'
print('Success')
success += 1
print 'Episode ends\n'
print('Episode ends\n')
break
state_idx = new_state

print 'Success persentage:', success/test_episodes
print('Success persentage:', success/test_episodes)

if __name__ == "__main__":
train()
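The from __future__ import division already at the top of this file is what keeps success/test_episodes a true (floating-point) division under Python 2 as well; without it, two integers would floor-divide to 0 whenever success < test_episodes. A small worked example with hypothetical counts:

from __future__ import division
success, test_episodes = 37, 50
print(success / test_episodes)   # 0.74 under both Python 2 and Python 3
# Without the __future__ import, Python 2 would print 0 (integer floor division).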
scripts/utils.py: 17 changes (9 additions, 8 deletions)
@@ -1,4 +1,5 @@
from __future__ import division
from __future__ import print_function
import random
from collections import namedtuple, Counter
import numpy as np
@@ -42,26 +43,26 @@ def teacher(e1, e2, num_paths, env, path = None):
intermediates = kb.pickRandomIntermediatesBetween(e1, e2, num_paths)
res_entity_lists = []
res_path_lists = []
for i in xrange(num_paths):
for i in range(num_paths):
suc1, entity_list1, path_list1 = BFS(kb, e1, intermediates[i])
suc2, entity_list2, path_list2 = BFS(kb, intermediates[i], e2)
if suc1 and suc2:
res_entity_lists.append(entity_list1 + entity_list2[1:])
res_path_lists.append(path_list1 + path_list2)
print 'BFS found paths:', len(res_path_lists)
print('BFS found paths:', len(res_path_lists))

# ---------- clean the path --------
res_entity_lists_new = []
res_path_lists_new = []
for entities, relations in zip(res_entity_lists, res_path_lists):
rel_ents = []
for i in xrange(len(entities)+len(relations)):
for i in range(len(entities)+len(relations)):
if i%2 == 0:
rel_ents.append(entities[int(i/2)])
else:
rel_ents.append(relations[int(i/2)])

#print rel_ents
#print(rel_ents)

entity_stats = Counter(entities).items()
duplicate_ents = [item for item in entity_stats if item[1]!=1]
@@ -84,14 +85,14 @@ def teacher(e1, e2, num_paths, env, path = None):
res_entity_lists_new.append(entities_new)
res_path_lists_new.append(relations_new)

print res_entity_lists_new
print res_path_lists_new
print(res_entity_lists_new)
print(res_path_lists_new)

good_episodes = []
targetID = env.entity2id_[e2]
for path in zip(res_entity_lists_new, res_path_lists_new):
good_episode = []
for i in xrange(len(path[0]) -1):
for i in range(len(path[0]) -1):
currID = env.entity2id_[path[0][i]]
nextID = env.entity2id_[path[0][i+1]]
state_curr = [currID, targetID, 0]
@@ -127,7 +128,7 @@ def prob_norm(probs):
return probs/sum(probs)

if __name__ == '__main__':
print prob_norm(np.array([1,1,1]))
print(prob_norm(np.array([1,1,1])))
#path_clean('/common/topic/webpage./common/webpage/category -> /m/08mbj5d -> /common/topic/webpage./common/webpage/category_inv -> /m/01d34b -> /common/topic/webpage./common/webpage/category -> /m/08mbj5d -> /common/topic/webpage./common/webpage/category_inv -> /m/0lfyx -> /common/topic/webpage./common/webpage/category -> /m/08mbj5d -> /common/topic/webpage./common/webpage/category_inv -> /m/01y67v -> /common/topic/webpage./common/webpage/category -> /m/08mbj5d -> /common/topic/webpage./common/webpage/category_inv -> /m/028qyn -> /people/person/nationality -> /m/09c7w0')
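The from __future__ import division already present in utils.py matters for prob_norm as well: np.array([1,1,1]) is an integer array, and dividing it by the integer sum(probs) floor-divides under plain Python 2. With the import, or under Python 3, the __main__ call prints the intended normalised probabilities; a sketch of the expected behaviour:

import numpy as np

probs = np.array([1, 1, 1])
print(probs / sum(probs))   # [0.33333333 0.33333333 0.33333333] with true division
# Under Python 2 without the __future__ import, the integer array floor-divides
# and prints [0 0 0], which would break the normalisation.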

