Skip to content

Commit

Permalink
Model contains instructions and memory by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcswillems committed May 31, 2018
1 parent 37a4738 commit 1854c6c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
11 changes: 6 additions & 5 deletions rl/scripts/train_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
help="number of epochs (default: 10)")
parser.add_argument("--batch-size", type=int, default=256,
help="batch size (default: 256)")
parser.add_argument("--model-instr", action="store_true", default=False,
help="use instructions in the model")
parser.add_argument("--model-mem", action="store_true", default=False,
help="use memory in the model")
parser.add_argument("--no-instr", action="store_true", default=False,
help="don't use instructions in the model")
parser.add_argument("--no-mem", action="store_true", default=False,
help="don't use memory in the model")
parser.add_argument("--arch", default='cnn1',
help="image embedding architecture")
args = parser.parse_args()
Expand Down Expand Up @@ -80,7 +80,8 @@

# Define actor-critic model

acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name)
acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down
10 changes: 5 additions & 5 deletions rl/scripts/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@
help="number of epochs for PPO (default: 4)")
parser.add_argument("--batch-size", type=int, default=256,
help="batch size for PPO (default: 256)")
parser.add_argument("--model-instr", action="store_true", default=False,
help="use instructions in the model")
parser.add_argument("--model-mem", action="store_true", default=False,
help="use memory in the model")
parser.add_argument("--no-instr", action="store_true", default=False,
help="don't use instructions in the model")
parser.add_argument("--no-mem", action="store_true", default=False,
help="don't use memory in the model")
parser.add_argument("--arch", default='cnn1',
help="image embedding architecture")
parser.add_argument("--exp-name", default=None,
Expand Down Expand Up @@ -94,7 +94,7 @@
# Define actor-critic model

acmodel = utils.load_model(obss_preprocessor.obs_space, envs[0].action_space, model_name,
args.model_instr, args.model_mem, args.arch)
not(args.no_instr), not(args.no_mem), args.arch)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down
10 changes: 5 additions & 5 deletions rl/scripts/train_wd.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
help="number of epochs (default: 10)")
parser.add_argument("--batch-size", type=int, default=10,
help="batch size (default: 10)")
parser.add_argument("--model-instr", action="store_true", default=False,
help="use instructions in the model")
parser.add_argument("--model-mem", action="store_true", default=False,
help="use memory in the model")
parser.add_argument("--no-instr", action="store_true", default=False,
help="don't use instructions in the model")
parser.add_argument("--no-mem", action="store_true", default=False,
help="don't use memory in the model")
parser.add_argument("--arch", default='cnn1',
help="image embedding architecture")
parser.add_argument("--discount", type=float, default=0.99,
Expand Down Expand Up @@ -89,7 +89,7 @@

# Define actor-critic model
acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name,
args.model_instr, args.model_mem, args.arch)
not(args.no_instr), not(args.no_mem), args.arch)

acmodel.train()

Expand Down

0 comments on commit 1854c6c

Please sign in to comment.