Skip to content

Commit

Permalink
Create model if not exists in the trainning scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
lcswillems committed May 31, 2018
1 parent b740d5a commit 50bac1a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
3 changes: 2 additions & 1 deletion rl/scripts/train_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@
# Define actor-critic model

acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch)
not(args.no_instr), not(args.no_mem), args.arch,
create_if_not_exists=True)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down
3 changes: 2 additions & 1 deletion rl/scripts/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
# Define actor-critic model

acmodel = utils.load_model(obss_preprocessor.obs_space, envs[0].action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch)
not(args.no_instr), not(args.no_mem), args.arch,
create_if_not_exists=True)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down
3 changes: 2 additions & 1 deletion rl/scripts/train_wd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@

# Define actor-critic model
acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch)
not(args.no_instr), not(args.no_mem), args.arch,
create_if_not_exists=True)

acmodel.train()

Expand Down
7 changes: 5 additions & 2 deletions rl/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ def get_model_path(model_name):
return os.path.join(get_model_dir(model_name), "model.pt")

def load_model(observation_space, action_space, model_name,
use_instr=False, use_memory=False, arch="cnn1"):
use_instr=False, use_memory=False, arch="cnn1",
create_if_not_exists=False):
path = get_model_path(model_name)
if os.path.exists(path):
acmodel = torch.load(path)
else:
elif create_if_not_exists:
acmodel = ACModel(observation_space, action_space,
use_instr, use_memory, arch)
else:
raise ValueError("No model at `{}`".format(path))
acmodel.eval()
return acmodel

Expand Down

0 comments on commit 50bac1a

Please sign in to comment.