Merge pull request aravindr93#33 from aravindr93/v1
Pushing version v1 to master
aravindr93 authored Dec 5, 2020
2 parents 57c4a8f + 51bddb2 commit 64f89f5
Showing 22 changed files with 1,223 additions and 263 deletions.
90 changes: 65 additions & 25 deletions mjrl/algos/model_accel/model_accel_npg.py
@@ -8,8 +8,7 @@
import time as timer
from torch.autograd import Variable
from mjrl.utils.gym_env import GymEnv
from mjrl.algos.model_accel.nn_dynamics import DynamicsModel

from mjrl.algos.model_accel.nn_dynamics import WorldModel
import mjrl.samplers.core as trajectory_sampler

# utility functions
@@ -22,33 +21,36 @@


class ModelAccelNPG(NPG):
def __init__(self, fitted_model=None,
def __init__(self, learned_model=None,
refine=False,
kappa=5.0,
plan_horizon=10,
plan_paths=100,
reward_function=None,
termination_function=None,
**kwargs):
super(ModelAccelNPG, self).__init__(**kwargs)
if fitted_model is None:
print("Algorithm requires a NN dynamics model (or list of fitted models)")
if learned_model is None:
print("Algorithm requires a (list of) learned dynamics model")
quit()
elif isinstance(fitted_model, DynamicsModel):
self.fitted_model = [fitted_model]
elif isinstance(learned_model, WorldModel):
self.learned_model = [learned_model]
else:
self.fitted_model = fitted_model
self.refine = refine
self.kappa, self.plan_horizon, self.plan_paths = kappa, plan_horizon, plan_paths
self.learned_model = learned_model
self.refine, self.kappa, self.plan_horizon, self.plan_paths = refine, kappa, plan_horizon, plan_paths
self.reward_function, self.termination_function = reward_function, termination_function

def to(self, device):
# Convert all the networks (except policy network which is clamped to CPU)
# to the specified device
for model in self.fitted_model:
for model in self.learned_model:
model.to(device)
self.baseline.model.to(device)
try: self.baseline.model.to(device)
except: pass

def is_cuda(self):
# Check if any of the networks are on GPU
model_cuda = [model.is_cuda() for model in self.fitted_model]
model_cuda = [model.is_cuda() for model in self.learned_model]
model_cuda = any(model_cuda)
baseline_cuda = next(self.baseline.model.parameters()).is_cuda
return any([model_cuda, baseline_cuda])
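
For reference, a minimal sketch of how the renamed argument could be used: an ensemble of WorldModel instances is passed as learned_model alongside the usual NPG components. The WorldModel, MLP, and MLPBaseline constructor arguments below are assumptions for illustration and may not match the repository exactly.

    # Sketch only: constructor arguments for WorldModel, MLP, and MLPBaseline
    # are assumed for illustration and may differ from the actual code.
    from mjrl.utils.gym_env import GymEnv
    from mjrl.policies.gaussian_mlp import MLP
    from mjrl.baselines.mlp_baseline import MLPBaseline
    from mjrl.algos.model_accel.nn_dynamics import WorldModel
    from mjrl.algos.model_accel.model_accel_npg import ModelAccelNPG

    env = GymEnv('Hopper-v3')
    policy = MLP(env.spec, hidden_sizes=(64, 64), seed=123)
    baseline = MLPBaseline(env.spec, learn_rate=1e-3)

    # ensemble of learned dynamics models; disagreement between members is
    # later used in train_step to truncate unreliable model rollouts
    ensemble = [WorldModel(state_dim=env.observation_dim,
                           act_dim=env.action_dim,
                           seed=123 + i) for i in range(4)]

    agent = ModelAccelNPG(learned_model=ensemble, env=env, policy=policy,
                          baseline=baseline, normalized_step_size=0.05,
                          save_logs=True)
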
@@ -61,6 +63,12 @@ def train_step(self, N,
gae_lambda=0.97,
num_cpu='max',
env_kwargs=None,
init_states=None,
reward_function=None,
termination_function=None,
truncate_lim=None,
truncate_reward=0.0,
**kwargs,
):

ts = timer.time()
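
A possible caller-side use of the new keyword arguments added to train_step above; buffer_states is a hypothetical list of previously visited states, and the horizon/gamma keywords are assumed to sit in the collapsed portion of the signature.

    # Illustrative call only; `agent` is the ModelAccelNPG instance from the
    # earlier sketch, `buffer_states` is a hypothetical list of states, and
    # horizon/gamma are assumed to be accepted by train_step.
    import numpy as np

    N = 64
    idx = np.random.choice(len(buffer_states), size=N)
    init_states = [buffer_states[i] for i in idx]   # a list of length N

    agent.train_step(N,
                     init_states=init_states,
                     horizon=32,             # length of each model rollout
                     gamma=0.995,
                     gae_lambda=0.97,
                     truncate_lim=0.05,      # max tolerated ensemble disagreement (MSE)
                     truncate_reward=-5.0)   # reward offset applied at the truncation point
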
@@ -78,21 +86,32 @@ def train_step(self, N,
print("Unsupported environment format")
raise AttributeError

# generate paths with fitted dynamics
# get correct behavior for reward and termination
reward_function = self.reward_function if reward_function is None else reward_function
termination_function = self.termination_function if termination_function is None else termination_function
if reward_function: assert callable(reward_function)
if termination_function: assert callable(termination_function)

# simulate trajectories with the learned model(s)
# we want to use the same task instances (e.g. goal locations) for each model in the ensemble
paths = []

# NOTE: When running on hardware, we need to load the set of initial states from a pickle file
# init_states = pickle.load(open(<some_file>.pickle, 'rb'))
# init_states = init_states[:N]
init_states = np.array([env.reset() for _ in range(N)])
# NOTE: We can optionally specify a set of initial states to perform the rollouts from
# This is useful for starting rollouts from the states in the replay buffer
init_states = [env.reset() for _ in range(N)] if init_states is None else init_states  # a list, per the asserts below
assert type(init_states) == list
assert len(init_states) == N

for model in self.fitted_model:
for model in self.learned_model:
# don't set the seed explicitly -- this will make rollouts follow the global seed
rollouts = policy_rollout(num_traj=N, env=env, policy=self.policy,
fitted_model=model, eval_mode=False, horizon=horizon,
learned_model=model, eval_mode=False, horizon=horizon,
init_state=init_states, seed=None)
self.env.env.env.compute_path_rewards(rollouts)
# use learned reward function if available
if model.learn_reward:
model.compute_path_rewards(rollouts)
else:
rollouts = reward_function(rollouts)
num_traj, horizon, state_dim = rollouts['observations'].shape
for i in range(num_traj):
path = dict()
@@ -109,10 +128,31 @@
# a function that can terminate paths appropriately.
# Otherwise, termination is not considered.

try:
paths = self.env.env.env.truncate_paths(paths)
except AttributeError:
pass
if callable(termination_function): paths = termination_function(paths)

# remove paths that are too short
paths = [path for path in paths if path['observations'].shape[0] >= 5]

# additional truncation based on prediction error across the ensemble
if truncate_lim is not None and len(self.learned_model) > 1:
for path in paths:
obs = path['observations']  # handle on the full array, used for the length check below
pred_err = np.zeros(obs.shape[0] - 1)
for model in self.learned_model:
s = path['observations'][:-1]
a = path['actions'][:-1]
s_next = path['observations'][1:]
pred = model.predict(s, a)
model_err = np.mean((s_next - pred)**2, axis=-1)
pred_err = np.maximum(pred_err, model_err)
violations = np.where(pred_err > truncate_lim)[0]
truncated = (not len(violations) == 0)
T = violations[0] + 1 if truncated else obs.shape[0]
T = max(4, T) # we don't want corner cases of very short truncation
path["observations"] = path["observations"][:T]
path["actions"] = path["actions"][:T]
path["rewards"] = path["rewards"][:T]
if truncated: path["rewards"][-1] += truncate_reward
path["terminated"] = False if T == obs.shape[0] else True

if self.save_logs:
self.logger.log_kv('time_sampling', timer.time() - ts)
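
The truncation block above caps each simulated path at the first step where the worst-case one-step prediction error across the ensemble exceeds truncate_lim. A self-contained sketch of that logic, assuming each model exposes a predict(s, a) method returning predicted next states as used in the diff:

    import numpy as np

    def truncate_on_disagreement(path, models, truncate_lim, truncate_reward=0.0):
        # path: dict with 'observations', 'actions', 'rewards' arrays
        # models: list of learned dynamics models exposing predict(s, a)
        obs, act = path['observations'], path['actions']
        # worst-case (over ensemble members) one-step prediction error at each step
        pred_err = np.zeros(obs.shape[0] - 1)
        for model in models:
            pred = model.predict(obs[:-1], act[:-1])
            model_err = np.mean((obs[1:] - pred) ** 2, axis=-1)
            pred_err = np.maximum(pred_err, model_err)
        violations = np.where(pred_err > truncate_lim)[0]
        truncated = len(violations) > 0
        T = violations[0] + 1 if truncated else obs.shape[0]
        T = max(4, T)  # avoid degenerate, very short paths
        path['observations'] = obs[:T]
        path['actions'] = act[:T]
        path['rewards'] = path['rewards'][:T]
        if truncated:
            path['rewards'][-1] += truncate_reward  # optional offset at the cut point
        path['terminated'] = (T != obs.shape[0])
        return path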