diff --git a/rl/scripts/enjoy.py b/rl/scripts/enjoy.py index e2b8606b..171ed20c 100644 --- a/rl/scripts/enjoy.py +++ b/rl/scripts/enjoy.py @@ -2,7 +2,7 @@ import argparse import gym -import levels +from babyai import levels import time import torch diff --git a/rl/scripts/evaluate.py b/rl/scripts/evaluate.py index fc8cd979..0445ab26 100644 --- a/rl/scripts/evaluate.py +++ b/rl/scripts/evaluate.py @@ -2,7 +2,7 @@ import argparse import gym -import levels +from babyai import levels import time import datetime import torch diff --git a/rl/scripts/make_agent_demos.py b/rl/scripts/make_agent_demos.py index 54fd574b..8e0e8f1e 100644 --- a/rl/scripts/make_agent_demos.py +++ b/rl/scripts/make_agent_demos.py @@ -2,7 +2,7 @@ import argparse import gym -import levels +from babyai import levels import torch_rl import utils diff --git a/rl/scripts/make_human_demos.py b/rl/scripts/make_human_demos.py index 737f237e..505ca772 100644 --- a/rl/scripts/make_human_demos.py +++ b/rl/scripts/make_human_demos.py @@ -6,7 +6,7 @@ import argparse import datetime import gym -import levels +from babyai import levels from PyQt5.QtCore import Qt, QTimer from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QInputDialog from PyQt5.QtWidgets import QLabel, QTextEdit, QFrame diff --git a/rl/scripts/train_il.py b/rl/scripts/train_il.py index 3ce4782c..7ec9d98a 100644 --- a/rl/scripts/train_il.py +++ b/rl/scripts/train_il.py @@ -3,7 +3,7 @@ import argparse import gym -import levels +from babyai import levels import time import datetime import numpy diff --git a/rl/scripts/train_rl.py b/rl/scripts/train_rl.py index 98781521..4dc54c1a 100644 --- a/rl/scripts/train_rl.py +++ b/rl/scripts/train_rl.py @@ -2,7 +2,7 @@ import argparse import gym -import levels +from babyai import levels import time import datetime import sys diff --git a/rl/scripts/train_wd.py b/rl/scripts/train_wd.py index eac569ba..ef5211fc 100644 --- a/rl/scripts/train_wd.py +++ b/rl/scripts/train_wd.py @@ -2,7 +2,7 @@ import 
argparse import gym -import levels +from babyai import levels import time import datetime import numpy as np @@ -122,34 +122,45 @@ def calculate_values(): for demo in demos: flat_demos += demo value_inds.append(value_inds[-1] + len(demo)) + + # Value inds point at the last step of each demonstration flat_demos = np.array(flat_demos) value_inds = value_inds[:-1] value_inds = [index-1 for index in value_inds][1:] +[len(flat_demos)-1] + # Value array with the length of flat_demos values = np.zeros([len(flat_demos)],dtype=np.float64) reward, done = flat_demos[:,2], flat_demos[:,3] + # Reshaping the reward reward = utils.reshape_reward(None,None,reward,None) + # Value at the last step of each demonstration = reward at that step values[value_inds]= reward[value_inds] + # last value keeps the values corresponding to last visited states last_value = values[value_inds] + while True: value_inds = [index-1 for index in value_inds] if value_inds[0] == -1: break done_step = done[value_inds] + # Removing indices of finished episodes value_inds = value_inds[:len(value_inds)-sum(done_step)] last_value = last_value[:len(last_value)-sum(done_step)] - + + # Calculating value of the states using value of previous states values[value_inds]= reward[value_inds] + args.discount*last_value last_value = values[value_inds] + # Appending values to corresponding demos flat_demos = [np.append(flat_demos[i],[values[i],]) for i in range(len(flat_demos))] new_demos = [] offset = 0 + # Reconstructing demos from flat_demos for demo in demos: new_demos.append(flat_demos[offset:offset+len(demo)]) offset += len(demo) @@ -201,12 +212,15 @@ def run_epoch_recurrence(): batch_size = args.batch_size offset = 0 assert len(demos) % batch_size == 0 + + # Log dictionary log = {"entropy": [],"value_loss": [],"policy_loss": []} for batch_index in range(len(demos)//batch_size): batch = demos[offset:offset+batch_size] batch.sort(key=len,reverse=True) + # Constructing flat batch and indices pointing to start of each
demonstration flat_batch = [] inds = [0] for demo in batch: @@ -215,39 +229,48 @@ def run_epoch_recurrence(): flat_batch = np.array(flat_batch) inds = inds[:-1] - + # Observations, true action, values and done for each of the stored demonstrations obss, action_true, values, done = flat_batch[:,0], flat_batch[:,1], flat_batch[:,4], flat_batch[:,3] action_true = torch.tensor([action for action in action_true],device=device,dtype=torch.float) values = torch.tensor([value for value in values],device=device,dtype=torch.float) - + # Memory to be stored memories = torch.zeros([len(flat_batch),acmodel.memory_size], device=device) memory = torch.zeros([batch_size,acmodel.memory_size], device=device) + # time_step_inds to be used for calculating the memory for each observation in flat_batch time_step = 0 time_step_inds = inds + # Loop terminates when every observation in the flat_batch has been handled while True: + # taking observations and done located at time_step_inds obs = obss[time_step_inds] done_step = done[time_step_inds] preprocessed_obs = obss_preprocessor(obs, device=device) with torch.no_grad(): if args.model_mem: + # taking the memory till the length of time_step_inds, as demos beyond that have already finished _, _, new_memory = acmodel(preprocessed_obs, memory[:len(time_step_inds),:]) if args.model_mem: for i in range(len(time_step_inds)): + # Copying to the memories at the corresponding locations memories[time_step_inds[i],:] = memory[i,:] memory[:len(time_step_inds),:] = new_memory + # Updating time_step_inds, by removing those indices corresponding to which the demonstrations have finished time_step_inds = time_step_inds[:len(time_step_inds)-sum(done_step)] if len(time_step_inds) == 0: break + # Incrementing the remaining indices time_step_inds = [index+1 for index in time_step_inds] + # Here, actual backprop up to args.recurrence happens while True: memory = memories[inds,:] final_loss = 0 + # taking observations till that recurrence, and computing backprop
on them for i in range(args.recurrence): obs = obss[inds] preprocessed_obs = obss_preprocessor(obs, device=device) @@ -280,9 +303,8 @@ def run_epoch_recurrence(): if len(inds) == 0: break - offset += batch_size - return log + return log total_start_time = time.time()