Skip to content

Commit

Permalink
Merge pull request mila-iqia#57 from maximecb/revert-56-check_loading
Browse files Browse the repository at this point in the history
Revert "load_model only loads the model, doesn't create anything"
  • Loading branch information
lcswillems authored Jun 1, 2018
2 parents 4a2c44e + bffe0be commit 3f3e923
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
14 changes: 6 additions & 8 deletions rl/scripts/train_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch
import torch.nn.functional as F
import torch_rl
from model import ACModel

import utils

Expand Down Expand Up @@ -82,10 +81,9 @@

# Define actor-critic model

acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
not(args.no_instr), not(args.no_mem), args.arch)
if args.model:
acmodel = utils.load_model(args.model)
acmodel = utils.load_model(obss_preprocessor.obs_space, env.action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch,
create_if_not_exists=True)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down Expand Up @@ -158,7 +156,7 @@

log_entropies.append(entropy.data[0])
log_policy_losses.append(policy_loss.data[0])

update_end_time = time.time()

# Print logs
Expand All @@ -175,7 +173,7 @@
"U {} | FPS {:04.0f} | D {} | H {:.3f} | pL {: .3f}"
.format(i, fps, duration,
log_entropy, log_policy_loss))

if args.tb:
writer.add_scalar("FPS", fps, i)
writer.add_scalar("duration", total_ellapsed_time, i)
Expand All @@ -192,4 +190,4 @@
utils.save_model(acmodel, model_name)
logger.log("Model is saved.")
if torch.cuda.is_available():
acmodel.cuda()
acmodel.cuda()
12 changes: 5 additions & 7 deletions rl/scripts/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import sys
import torch
import torch_rl
from model import ACModel

import utils

Expand Down Expand Up @@ -94,10 +93,9 @@

# Define actor-critic model

acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
not(args.no_instr), not(args.no_mem), args.arch)
if args.model:
acmodel = utils.load_model(args.model)
acmodel = utils.load_model(obss_preprocessor.obs_space, envs[0].action_space, model_name,
not(args.no_instr), not(args.no_mem), args.arch,
create_if_not_exists=True)
if torch.cuda.is_available():
acmodel.cuda()

Expand Down Expand Up @@ -140,7 +138,7 @@
update_start_time = time.time()
logs = algo.update_parameters()
update_end_time = time.time()

num_frames += logs["num_frames"]
i += 1

Expand Down Expand Up @@ -185,4 +183,4 @@
utils.save_model(acmodel, model_name)
logger.info("Model is saved.")
if torch.cuda.is_available():
acmodel.cuda()
acmodel.cuda()
17 changes: 13 additions & 4 deletions rl/utils/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch

from model import ACModel
import utils

def get_model_dir(model_name):
Expand All @@ -9,13 +10,21 @@ def get_model_dir(model_name):
def get_model_path(model_name):
return os.path.join(get_model_dir(model_name), "model.pt")

def load_model(model_name):
def load_model(observation_space, action_space, model_name,
use_instr=False, use_memory=False, arch="cnn1",
create_if_not_exists=False):
path = get_model_path(model_name)
if not os.path.exists(path):
if os.path.exists(path):
acmodel = torch.load(path)
elif create_if_not_exists:
acmodel = ACModel(observation_space, action_space,
use_instr, use_memory, arch)
else:
raise ValueError("No model at `{}`".format(path))
return torch.load(path)
acmodel.eval()
return acmodel

def save_model(acmodel, model_name):
path = get_model_path(model_name)
utils.create_folders_if_necessary(path)
torch.save(acmodel, path)
torch.save(acmodel, path)

0 comments on commit 3f3e923

Please sign in to comment.