Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix issue with loading demos when making demos * reindentation + small bug fix * - Added head for extra binary information that can be used for an auxiliary supervised loss - PEP8 reformatting * Line too long - changed * Modified how experiences are collected when there are extra binary information to be used from the environment * A few comments that were super helpful to me. I guess this commit doesn't need to be in the PR * Taking into account the extra binary information to evaluate a supervised loss * Small modifications to the model to output extra_logits (no need for sigmoid layer) * - Added the possibility of specifying how many extra binary information to use from the environment (they have to specified in the `info` part of the gym step function). - Logging of the corresponding supervised loss and accuracy * clearer help for an argparse argument * The environment yields, at each step, if the new state is already visited or not * supervised loss coeff can be a float, and not necessarily an int * fix a bug at evaluation time due to extra outputs of the model when there is an auxiliary loss * quick hack to show the supervised loss coefficient in the model name for easier comparison * - Reseeding after initializing the model to make sure to get consistent results. - This commit doesn't need to be in the PR. * typo * Defining the extra info head after the actor and the critic so that the initialization process makes the results consistent between when we're not using extra info and when we're using it with a supervised loss coef of 0. * Log total loss * small bug fix * typo * added logging of prevalence in the supervised auxiliary task for debugging/understanding * default extra binary info to False for retro compatibility * added more binary info * - fixed bug in enjoy - made enjoy and evaluate compatible with the extra binary info setting * - reuse previous deleted normalization of weights - define as many extra_heads as passed to the model through a dictionary - define an extra_predictions dictionary to be returned - always return a dictionary in the forward model to avoid too many conditionals in scripts that use the model * use extra-info as a list argument containing the names of the extra info wanted from the environment * Because acmodel.forward returns a dictionary. This means that at each call, we should change the containing variable * return extra information at each step * - collect the experiments the right way - update the parameters in the presence of extra info for supervised aux tasks - adequate logging - change of model * change-list * Use ModuleDict instead of dict. !! REQUIRES pytorch 0.4.1 !! * Add a new aux loss - requires a small change in minigrid * reintroduce the prevalences * stop using numbers to check for presence of objects in observation * factorization * docstring * removed unnecessary argument from ModelAgent * fix bug introduced in bf10a286a89f8e15ee46d7cb7d41f374601dda28 * fix logging issue * - add a conditional in evaluation of grad norm because sometimes we use different model origins - add the option of using a pre-trained model and have the fine-tuned version of it saved elsewhere (otherwise, one cannot use the same pre-trained model to finetune 2 different models in parallel) * allows using extra heads even for pre-trained models * fix small bug with 'continuous' type * fix small bug with 'continuous' type * change comment - again, doesn't have to be merged * use a new class instead of overloading the ppo file * model refactorization * more refactorization * update requirement of pytorch version * add some comments for the classes introduced * bug fix * - Use a wrapper for supervised auxiliary losses - added binary info: does the agent think they did the same action as would the bot - bugfix in rl/utils/supervised_losses * comment on function * move function to wrapper * - Use different wrappers for each auxiliary task - Rename extra info to aux info * rename extra info to aux info * revert file commited by mistake * remove float() when dealing with binary information * rename args
- Loading branch information