Skip to content

Commit

Permalink
Auxiliary supervised losses (#159)
Browse files Browse the repository at this point in the history
* fix issue with loading demos when making demos

* reindentation + small bug fix

* - Added head for extra binary information that can be used for an auxiliary supervised loss
- PEP8 reformatting

* Line too long - changed

* Modified how experiences are collected when there are extra binary information to be used from the environment

* A few comments that were super helpful to me. I guess this commit doesn't need to be in the PR

* Taking into account the extra binary information to evaluate a supervised loss

* Small modifications to the model to output extra_logits (no need for sigmoid layer)

* - Added the possibility of specifying how many extra binary information to use from the environment (they have to specified in the `info` part of the gym step function).
- Logging of the corresponding supervised loss and accuracy

* clearer help for an argparse argument

* The environment yields, at each step, if the new state is already visited or not

* supervised loss coeff can be a float, and not necessarily an int

* fix a bug at evaluation time due to extra outputs of the model when there is an auxiliary loss

* quick hack to show the supervised loss coefficient in the model name for easier comparison

* - Reseeding after initializing the model to make sure to get consistent results.
- This commit doesn't need to be in the PR.

* typo

* Defining the extra info head after the actor and the critic so that the initialization process makes the results consistent between when we're not using extra info and when we're using it with a supervised loss coef of 0.

* Log total loss

* small bug fix

* typo

* added logging of prevalence in the supervised auxiliary task for debugging/understanding

* default extra binary info to False for retro compatibility

* added more binary info

* - fixed bug in enjoy
- made enjoy and evaluate compatible with the extra binary info setting

* - reuse previous deleted normalization of weights
- define as many extra_heads as passed to the model through a dictionary
- define an extra_predictions dictionary to be returned
- always return a dictionary in the forward model to avoid too many conditionals in scripts that use the model

* use extra-info as a list argument containing the names of the extra info wanted from the environment

* Because acmodel.forward returns a dictionary. This means that at each call, we should change the containing variable

* return extra information at each step

* - collect the experiments the right way
- update the parameters in the presence of extra info for supervised aux tasks
- adequate logging
- change of model

* change-list

* Use ModuleDict instead of dict. !! REQUIRES pytorch 0.4.1 !!

* Add a new aux loss - requires a small change in minigrid

* reintroduce the prevalences

* stop using numbers to check for presence of objects in observation

* factorization

* docstring

* removed unnecessary argument from ModelAgent

* fix bug introduced in bf10a286a89f8e15ee46d7cb7d41f374601dda28

* fix logging issue

* - add a conditional in evaluation of grad norm because sometimes we use different model origins
- add the option of using a pre-trained model and have the fine-tuned version of it saved elsewhere (otherwise, one cannot use the same pre-trained model to finetune 2 different models in parallel)

* allows using extra heads even for pre-trained models

* fix small bug with 'continuous' type

* fix small bug with 'continuous' type

* change comment - again, doesn't have to be merged

* use a new class instead of overloading the ppo file

* model refactorization

* more refactorization

* update requirement of pytorch version

* add some comments for the classes introduced

* bug fix

* - Use a wrapper for supervised auxiliary losses
- added binary info: does the agent think they did the same action as would the bot
- bugfix in rl/utils/supervised_losses

* comment on function

* move function to wrapper

* - Use different wrappers for each auxiliary task
- Rename extra info to aux info

* rename extra info to aux info

* revert file commited by mistake

* remove float() when dealing with binary information

* rename args
  • Loading branch information
saleml authored and maximecb committed Aug 30, 2018
1 parent df4e7c3 commit c694068
Show file tree
Hide file tree
Showing 14 changed files with 585 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Requirements:
- OpenAI gym
- NumPy
- PyQT5
- PyTorch 0.4+
- PyTorch 0.4.1+

Start by manually installing PyTorch. See the [PyTorch website](http://pytorch.org/)
for installation instructions specific to your platform.
Expand Down
8 changes: 6 additions & 2 deletions babyai/algos/imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def run_epoch_recurrence_one_batch(self, batch, is_training=False):
preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
with torch.no_grad():
# taking the memory till the length of time_step_inds, as demos beyond that have already finished
_, _, new_memory = self.acmodel(preprocessed_obs, memory[:len(inds), :])
new_memory = self.acmodel(preprocessed_obs, memory[:len(inds), :])['memory']

for i in range(len(inds)):
# Copying to the memories at the corresponding locations
Expand Down Expand Up @@ -189,7 +189,11 @@ def run_epoch_recurrence_one_batch(self, batch, is_training=False):
preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
action_step = action_true[indexes]
mask_step = mask[indexes]
dist, value, memory = self.acmodel(preprocessed_obs, memory * mask_step)
model_results = self.acmodel(preprocessed_obs, memory * mask_step)
dist = model_results['dist']
value = model_results['value']
memory = model_results['memory']

entropy = dist.entropy().mean()
policy_loss = -dist.log_prob(action_step).mean()
loss = policy_loss - self.args.entropy_coef * entropy
Expand Down
195 changes: 195 additions & 0 deletions babyai/levels/supervised_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import gym
from babyai.agents.bot import Bot
from gym_minigrid.minigrid import OBJECT_TO_IDX, Grid
from .verifier import *


def wrap_env(env, aux_info):
'''
helper function that callss the defined wrappers depending on the what information is required
'''
if 'seen_state' in aux_info:
env = SeenStateWrapper(env)
if 'visit_proportion' in aux_info:
env = VisitProportionWrapper(env)
if 'see_door' in aux_info:
env = SeeDoorWrapper(env)
if 'see_obj' in aux_info:
env = SeeObjWrapper(env)
if 'in_front_of_what' in aux_info:
env = InForntOfWhatWrapper(env)
if 'obj_in_instr' in aux_info:
env = ObjInInstrWrapper(env)
if 'bot_action' in aux_info:
env = BotActionWrapper(env)
return env


class SeenStateWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to whether
the new state is already visited or not
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)

# Define a set of seen states. A state is represent by a tuple ((x, y), direction)
self.seen_states = set()

# Append the current state to the seen states
# The state is defined in the reset function of the MiniGridEnv class
self.seen_states.add((tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir))

return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)

if (tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir) in self.seen_states:
seen_state = True
else:
self.seen_states.add((tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir))
seen_state = False

info['seen_state'] = seen_state

return obs, reward, done, info


class VisitProportionWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to the number of times
the new state has been visited before, divided by the total number of steps
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)

# Define a dict of seen states and number of times seen. A state is represent by a tuple ((x, y), direction)
self.seen_states_dict = dict()

# Append the current state to the seen states
# The state is defined in the reset function of the MiniGridEnv class
self.seen_states_dict[(tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir)] = 1

# Instantiate a counter of total steps
self.total_steps = 0

return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)

self.total_steps += 1
if (tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir) in self.seen_states_dict:
self.seen_states_dict[(tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir)] += 1
else:
self.seen_states_dict[(tuple(self.env.unwrapped.agent_pos), self.env.unwrapped.agent_dir)] = 1

info['visit_proportion'] = ((self.seen_states_dict[(tuple(self.env.unwrapped.agent_pos),
self.env.unwrapped.agent_dir)]
- 1) / self.total_steps)

return obs, reward, done, info


class SeeDoorWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to whether
the current observation contains a door, locked or not
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)
info['see_door'] = (None, 'door') in Grid.decode(obs['image'])
return obs, reward, done, info


class SeeObjWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to whether
the current observation contains a key, ball, or box
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)
info['see_obj'] = any([obj in Grid.decode(obs['image']) for obj in
((None, 'key'), (None, 'ball'), (None, 'box'))
])
return obs, reward, done, info


class InForntOfWhatWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to which of
empty cell/wall/door/key/box/ball is in the cell right in front of the agent
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)
cell_in_front = self.env.unwrapped.grid.get(*self.env.unwrapped.front_pos)
info['in_front_of_what'] = OBJECT_TO_IDX[cell_in_front.type] if cell_in_front else 0 # int 0--8
return obs, reward, done, info


class ObjInInstrWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to whether an object
described in the instruction appears in the current observation
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return obs

def obj_in_mission(self, instr):
if isinstance(instr, PutNextInstr):
return [(instr.desc_fixed.color, instr.desc_fixed.type),
(instr.desc_move.color, instr.desc_move.type)]
if isinstance(instr, SeqInstr):
return self.obj_in_mission(instr.instr_a) + self.obj_in_mission(instr.instr_b)
else:
return [(instr.desc.color, instr.desc.type)]

def step(self, action):
obs, reward, done, info = self.env.step(action)
info['obj_in_instr'] = any([obj in Grid.decode(obs['image'])
for obj in self.obj_in_mission(self.env.unwrapped.instrs)])
return obs, reward, done, info


class BotActionWrapper(gym.Wrapper):
'''
Wrapper that adds an entry to the info dic of the step function's output that corresponds to whether
the action taken corresponds to the action the GOFAI bot would have taken
'''

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
self.expert = Bot(self.env.unwrapped)
return obs

def step(self, action):
obs, reward, done, info = self.env.step(action)

try:
expert_action = self.expert.step()
except:
expert_action = None

info['bot_action'] = action == expert_action

return obs, reward, done, info
Loading

0 comments on commit c694068

Please sign in to comment.