Commit d53e941

Merge branch 'rl' of https://github.com/maximecb/baby-ai-game into rl

lcswillems committed May 29, 2018
2 parents 50c00cb + fd00b32
Showing 7 changed files with 34 additions and 12 deletions.
2 changes: 1 addition & 1 deletion rl/scripts/enjoy.py
@@ -2,7 +2,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import time
 import torch
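Each of the seven changed scripts gets the same one-line import fix: the levels module now lives inside the babyai package rather than as a top-level module. A minimal usage sketch, assuming the import registers the BabyAI gym environments (as in later versions of the package) and with an illustrative level id:

import gym
from babyai import levels  # importing is assumed to register the BabyAI-* envs

env = gym.make('BabyAI-GoToRedBall-v0')  # illustrative level id
obs = env.reset()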
2 changes: 1 addition & 1 deletion rl/scripts/evaluate.py
@@ -2,7 +2,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import time
 import datetime
 import torch
2 changes: 1 addition & 1 deletion rl/scripts/make_agent_demos.py
@@ -2,7 +2,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import torch_rl
 
 import utils
2 changes: 1 addition & 1 deletion rl/scripts/make_human_demos.py
@@ -6,7 +6,7 @@
 import argparse
 import datetime
 import gym
-import levels
+from babyai import levels
 from PyQt5.QtCore import Qt, QTimer
 from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QInputDialog
 from PyQt5.QtWidgets import QLabel, QTextEdit, QFrame
2 changes: 1 addition & 1 deletion rl/scripts/train_il.py
@@ -3,7 +3,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import time
 import datetime
 import numpy
2 changes: 1 addition & 1 deletion rl/scripts/train_rl.py
@@ -2,7 +2,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import time
 import datetime
 import sys
34 changes: 28 additions & 6 deletions rl/scripts/train_wd.py
@@ -2,7 +2,7 @@
 
 import argparse
 import gym
-import levels
+from babyai import levels
 import time
 import datetime
 import numpy as np
@@ -122,34 +122,45 @@ def calculate_values():
    for demo in demos:
        flat_demos += demo
        value_inds.append(value_inds[-1] + len(demo))

    # value_inds point at the last transition of each demo
    flat_demos = np.array(flat_demos)
    value_inds = value_inds[:-1]
    value_inds = [index-1 for index in value_inds][1:] + [len(flat_demos)-1]
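    # For example, with two demos of lengths 3 and 2, value_inds is built up
    # as [0, 3, 5], trimmed to [0, 3], and remapped to [2, 4]: the flat index
    # of the last transition of each demo.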

    # Value array with the same length as flat_demos
    values = np.zeros([len(flat_demos)], dtype=np.float64)

    reward, done = flat_demos[:,2], flat_demos[:,3]

    # Reshaping the rewards
    reward = utils.reshape_reward(None, None, reward, None)

    # The value at each demo's last step is the reward at that step
    values[value_inds] = reward[value_inds]
    # last_value keeps the values of the most recently visited states
    last_value = values[value_inds]


    while True:
        value_inds = [index-1 for index in value_inds]
        if value_inds[0] == -1:
            break
        done_step = done[value_inds]
        # Removing the indices of finished episodes
        value_inds = value_inds[:len(value_inds)-sum(done_step)]
        last_value = last_value[:len(last_value)-sum(done_step)]

        # Calculating the value of each state from the value of its successor state
        values[value_inds] = reward[value_inds] + args.discount*last_value
        last_value = values[value_inds]
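    # In effect the loop computes the discounted return backwards,
    #   values[t] = reward[t] + args.discount * values[t+1],
    # dropping finished demos from value_inds so that no value is
    # bootstrapped across an episode boundary.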

    # Appending each computed value to its transition
    flat_demos = [np.append(flat_demos[i], [values[i],]) for i in range(len(flat_demos))]
    new_demos = []
    offset = 0

    # Reconstructing the demos from flat_demos
    for demo in demos:
        new_demos.append(flat_demos[offset:offset+len(demo)])
        offset += len(demo)
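For reference, a minimal standalone sketch of the same backward computation, assuming each transition carries a reward and a done flag; the script's reward reshaping via utils.reshape_reward is omitted here:

import numpy as np

def compute_values(rewards, dones, discount=0.99):
    # Walk backwards: values[t] = rewards[t] + discount * values[t+1],
    # resetting at episode boundaries so demos never bootstrap into each other
    values = np.zeros(len(rewards), dtype=np.float64)
    next_value = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            next_value = 0.0
        values[t] = rewards[t] + discount * next_value
        next_value = values[t]
    return values

# Two concatenated demos of lengths 3 and 2:
print(compute_values([0., 0., 1., 0., 1.], [False, False, True, False, True], discount=0.9))
# -> [0.81, 0.9, 1.0, 0.9, 1.0]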
@@ -201,12 +212,15 @@ def run_epoch_recurrence():
    batch_size = args.batch_size
    offset = 0
    assert len(demos) % batch_size == 0

    # Log dictionary
    log = {"entropy": [], "value_loss": [], "policy_loss": []}

    for batch_index in range(len(demos)//batch_size):
        batch = demos[offset:offset+batch_size]
        batch.sort(key=len, reverse=True)

        # Constructing the flat batch and the indices pointing to the start of each demonstration
        flat_batch = []
        inds = [0]
        for demo in batch:
@@ -215,39 +229,48 @@
        flat_batch = np.array(flat_batch)
        inds = inds[:-1]
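        # For example, for batch demo lengths [3, 2], inds is built up as
        # [0, 3, 5] and trimmed to [0, 3]: the flat index where each
        # demonstration starts.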


        # Observations, true actions, values and done flags for each stored demonstration
        obss, action_true, values, done = flat_batch[:,0], flat_batch[:,1], flat_batch[:,4], flat_batch[:,3]
        action_true = torch.tensor([action for action in action_true], device=device, dtype=torch.float)
        values = torch.tensor([value for value in values], device=device, dtype=torch.float)

        # Memories to be stored for every observation in the flat batch
        memories = torch.zeros([len(flat_batch), acmodel.memory_size], device=device)
        memory = torch.zeros([batch_size, acmodel.memory_size], device=device)

        # time_step_inds is used for calculating the memory for each observation in flat_batch
        time_step = 0
        time_step_inds = inds

        # The loop terminates when every observation in the flat batch has been handled
        while True:
            # Taking the observations and done flags located at time_step_inds
            obs = obss[time_step_inds]
            done_step = done[time_step_inds]
            preprocessed_obs = obss_preprocessor(obs, device=device)
            with torch.no_grad():
                if args.model_mem:
                    # Using only the first len(time_step_inds) memory rows, as the demos beyond them have already finished
                    _, _, new_memory = acmodel(preprocessed_obs, memory[:len(time_step_inds),:])

            if args.model_mem:
                for i in range(len(time_step_inds)):
                    # Copying the memory rows into memories at the corresponding locations
                    memories[time_step_inds[i],:] = memory[i,:]
                memory[:len(time_step_inds),:] = new_memory
            # Removing the indices whose demonstrations have finished
            time_step_inds = time_step_inds[:len(time_step_inds)-sum(done_step)]
            if len(time_step_inds) == 0:
                break
            # Incrementing the remaining indices
            time_step_inds = [index+1 for index in time_step_inds]
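        # Net effect: the demos are replayed in lockstep, one timestep per
        # iteration, so memories[i] holds the recurrent state the model had
        # before seeing flat_batch[i]; sorting the batch by length lets
        # finished demos simply drop off the end of time_step_inds.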


        # Here the actual backpropagation through up to args.recurrence steps happens
        while True:
            memory = memories[inds,:]
            final_loss = 0
            # Taking the observations for each recurrence step and backpropagating through them
            for i in range(args.recurrence):
                obs = obss[inds]
                preprocessed_obs = obss_preprocessor(obs, device=device)
@@ -280,9 +303,8 @@

            if len(inds) == 0:
                break
        offset += batch_size
-        return log
+    return log
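The loop above amounts to truncated backpropagation through time: gradients flow through at most args.recurrence consecutive timesteps, starting from the stored memories. A minimal self-contained sketch of that idea, with an illustrative toy model rather than the script's acmodel:

import torch

# Stand-ins for the recurrent part and the policy head of an acmodel-style network
rnn = torch.nn.GRUCell(8, 16)
head = torch.nn.Linear(16, 4)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()))

obss = torch.randn(12, 5, 8)            # 12 timesteps, batch of 5 demos, toy features
actions = torch.randint(0, 4, (12, 5))  # toy "true actions"
memory = torch.zeros(5, 16)
recurrence = 4

for start in range(0, 12, recurrence):
    final_loss = 0
    for t in range(start, start + recurrence):
        memory = rnn(obss[t], memory)
        logits = head(memory)
        final_loss += torch.nn.functional.cross_entropy(logits, actions[t])
    optimizer.zero_grad()
    final_loss.backward()
    optimizer.step()
    # Cutting the graph between recurrence windows
    memory = memory.detach()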

total_start_time = time.time()
