Commit b517f72
1. Improve training efficiency
2. Update RLBench environment wrapper

Signed-off-by: id9502 <[email protected]>
id9502 committed Feb 13, 2020
1 parent 82873f1 commit b517f72
Showing 27 changed files with 427 additions and 401 deletions.
47 changes: 47 additions & 0 deletions .gitignore
@@ -0,0 +1,47 @@
# Compiled source #
###################
*.com
*.class
*.dll
*.exe
*.o
*.so

# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

# Logs and databases #
######################
*.log
*.sql
*.sqlite

# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Generated training results #
######################
assets

# Generated Byte-compiled / optimized / DLL files #
######################
__pycache__/
*.py[cod]
*$py.class
3 changes: 0 additions & 3 deletions README.md
@@ -6,9 +6,6 @@ RL/IRL benchmark with pytorch
Python version >= 3.6

## Install RLFrame
python setup.py install

**TODO: setup.py**

# Getting Started
TODO
Empty file added core/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions core/agent/agent.py
@@ -35,10 +35,10 @@ def __init__(self, config: ParamDict, environment: Environment, policy: Policy,
# filter which will be copied to child thread and also be kept in main thread
self._filter = deepcopy(filter_op)

self._filter.init()
self._filter.to_device(self.device)
self._policy.init()
self._filter.init()
self._policy.to_device(self.device)
self._policy.init()

def __del__(self):
self._filter.finalize()
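Note on the reordered calls above: after this commit the agent moves each component to its device before calling `init()`, instead of the other way round. The sketch below illustrates that ordering with a made-up `ToyPolicy` class, not the repository's actual `Policy` API; the presumed benefit (buffers created during `init()` land directly on the target device) is an assumption, not something stated in the commit.

```python
import torch
import torch.nn as nn

class ToyPolicy:
    """Stand-in for the repo's Policy; only to_device() and init() matter here."""

    def __init__(self):
        self.device = torch.device("cpu")
        self.net = nn.Linear(4, 2)

    def to_device(self, device: torch.device):
        self.device = device
        self.net.to(device)

    def init(self):
        # Anything allocated at init time now lands on the device set beforehand.
        self.log_std = torch.zeros(2, device=self.device)

policy = ToyPolicy()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy.to_device(device)  # move first ...
policy.init()             # ... then initialize, matching the updated agent.py ordering
```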
4 changes: 2 additions & 2 deletions core/agent/agent_async.py
@@ -132,8 +132,8 @@ def _environment_worker(setups: ParamDict, pipe_cmd, pipe_step, read_lock, step_
np.random.seed(seed)
torch.manual_seed(seed)
environment.init(display=False)
filter_op.init()
filter_op.to_device(torch.device("cpu"))
filter_op.init()
# -1: syncing, 0: waiting for command, 1: waiting for action
local_state = 0
step_buffer = []
@@ -210,8 +210,8 @@ def _policy_worker(setups: ParamDict, pipe_param, pipe_steps, read_lock, sync_si
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
policy.init()
policy.to_device(device)
policy.init()
# -1: syncing, 0: waiting for state
local_state = 0
max_batchsz = 8
4 changes: 2 additions & 2 deletions core/agent/agent_sync.py
@@ -109,10 +109,10 @@ def _sampler_worker(setups: ParamDict, pipe_cmd, pipe_param, read_lock, sync_sig
np.random.seed(seed)
torch.manual_seed(seed)
environment.init(display=False)
filter_op.to_device(torch.device("cpu"))
filter_op.init()
filter_op.to_device(device)
policy.init()
policy.to_device(device)
policy.init()

# -1: syncing, 0: waiting for new command, 1: sampling
local_state = 0
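The async and sync worker functions above share the same startup recipe after this change: seed `random`, NumPy and PyTorch inside the child process, then move the filter and policy to their devices before initializing them. Below is a minimal sketch of that pattern; `worker_setup` is a hypothetical helper written only for illustration, the policy/filter objects are stand-ins, and which device the filter actually ends up on in each worker is not asserted here.

```python
import random
import numpy as np
import torch

def worker_setup(seed: int, policy, filter_op, policy_device: torch.device,
                 filter_device: torch.device = torch.device("cpu")):
    """Hypothetical helper mirroring the per-worker startup in agent_async/agent_sync."""
    # Seed every RNG in the child process so sampling stays reproducible per worker.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Same ordering as the updated workers: move to the device first, then init.
    filter_op.to_device(filter_device)
    filter_op.init()
    policy.to_device(policy_device)
    policy.init()
```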
80 changes: 34 additions & 46 deletions core/algorithm/ppo.py
@@ -1,62 +1,40 @@
import torch
import numpy as np
from torch.optim.adam import Adam
from torch.optim.lbfgs import LBFGS
from core.model.policy_with_value import PolicyWithValue
from core.common import StepDictList, ParamDict


def update_value_net(value_net, states, returns, l2_reg):
optimizer = LBFGS(value_net.parameters(), max_iter=25, history_size=5)
def get_tensor(batch, device):
states = torch.as_tensor(batch["states"], dtype=torch.float32, device=device)
actions = torch.as_tensor(batch["actions"], dtype=torch.float32, device=device)
advantages = torch.as_tensor(batch["advantages"], dtype=torch.float32, device=device)
returns = torch.as_tensor(batch["returns"], dtype=torch.float32, device=device)
return states, actions, advantages, returns


def closure():
def update_value_net(value_net, optimizer, states, returns):
for _ in range(25):
optimizer.zero_grad()
values_pred = value_net(states)
value_loss = (values_pred - returns).pow(2).mean()

# weight decay
for param in value_net.parameters():
value_loss += param.pow(2).sum() * l2_reg
value_loss.backward()
return value_loss

optimizer.step(closure)
torch.nn.utils.clip_grad_norm_(value_net.parameters(), 0.5)
optimizer.step()


def get_tensor(policy, reply_memory, device):
policy.estimate_advantage(reply_memory)
states = []
actions = []
advantages = []
returns = []

for b in reply_memory:
advantages.extend([[tr["advantage"]] for tr in b["trajectory"]])
returns.extend([[tr["return"]] for tr in b["trajectory"]])
states.extend([tr['s'] for tr in b["trajectory"]])
actions.extend([tr['a'] for tr in b["trajectory"]])

states = torch.as_tensor(states, dtype=torch.float32, device=device)
actions = torch.as_tensor(actions, dtype=torch.float32, device=device)
advantages = torch.as_tensor(advantages, dtype=torch.float32, device=device)
returns = torch.as_tensor(returns, dtype=torch.float32, device=device)
return states, actions, advantages, returns


def ppo_step(config: ParamDict, replay_memory: StepDictList, policy: PolicyWithValue):
def ppo_step(config: ParamDict, batch: StepDictList, policy: PolicyWithValue):
lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
"current training iter", "max iter", "optimize batch size")
lam_entropy = 0.0
states, actions, advantages, returns = get_tensor(policy, replay_memory, policy.device)

"""update critic"""
update_value_net(policy.value_net, states, returns, l2_reg)
states, actions, advantages, returns = get_tensor(batch, policy.device)

"""update policy"""
lr_mult = max(1.0 - i_iter / max_iter, 0.)
clip_epsilon = clip_epsilon * lr_mult
optimizer = Adam(policy.policy_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)

optimizer_policy = Adam(policy.policy_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)
optimizer_value = Adam(policy.value_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)

with torch.no_grad():
fixed_log_probs = policy.policy_net.get_log_prob(states, actions).detach()
@@ -67,16 +45,26 @@ def ppo_step(config: ParamDict, replay_memory: StepDictList, policy: PolicyWithV
np.random.shuffle(inds)

"""perform mini-batch PPO update"""
for i_b in range(int(np.ceil(states.size(0) / mini_batch_sz))):
for i_b in range(int(np.ceil(inds.size(0) / mini_batch_sz))):
ind = inds[i_b * mini_batch_sz: min((i_b+1) * mini_batch_sz, inds.size(0))]

log_probs, entropy = policy.policy_net.get_log_prob_entropy(states[ind], actions[ind])
ratio = torch.exp(log_probs - fixed_log_probs[ind])
surr1 = ratio * advantages[ind]
surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages[ind]
states_i = states[ind]
actions_i = actions[ind]
returns_i = returns[ind]
advantages_i = advantages[ind]
log_probs_i = fixed_log_probs[ind]

"""update critic"""
update_value_net(policy.value_net, optimizer_value, states_i, returns_i)

"""update policy"""
log_probs, entropy = policy.policy_net.get_log_prob_entropy(states_i, actions_i)
ratio = torch.exp(log_probs - log_probs_i)
surr1 = ratio * advantages_i
surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages_i
policy_surr = -torch.min(surr1, surr2).mean()
policy_surr -= entropy.mean() * lam_entropy
optimizer.zero_grad()
policy_surr = policy_surr - entropy.mean() * lam_entropy
optimizer_policy.zero_grad()
policy_surr.backward()
torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
optimizer.step()
optimizer_policy.step()
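The rewritten `ppo_step` above drops the LBFGS value fit in favor of Adam for both networks, shuffles indices once per call, and applies the clipped-surrogate update per mini-batch with gradient-norm clipping. The following is a self-contained, simplified sketch of that loop under assumed shapes: plain `nn.Sequential` networks with a Normal policy head stand in for `PolicyWithValue`, and the hyperparameters are placeholders, so treat it as an illustration of the technique rather than the project's exact code.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam

# Toy stand-ins for the repo's policy/value networks (3-d state, 1-d action).
policy_net = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 1))
log_std = nn.Parameter(torch.zeros(1))
value_net = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 1))

optimizer_policy = Adam(list(policy_net.parameters()) + [log_std], lr=3e-4)
optimizer_value = Adam(value_net.parameters(), lr=3e-4)


def ppo_update(states, actions, advantages, returns,
               clip_epsilon=0.2, value_iters=25, mini_batch_sz=64):
    """One clipped-PPO pass over the data in shuffled mini-batches (sketch)."""
    with torch.no_grad():
        dist = Normal(policy_net(states), log_std.exp())
        fixed_log_probs = dist.log_prob(actions).sum(-1, keepdim=True)

    inds = np.arange(states.size(0))
    np.random.shuffle(inds)

    for i_b in range(int(np.ceil(len(inds) / mini_batch_sz))):
        ind = torch.as_tensor(inds[i_b * mini_batch_sz:(i_b + 1) * mini_batch_sz],
                              dtype=torch.long)
        s, a = states[ind], actions[ind]
        adv, ret, old_lp = advantages[ind], returns[ind], fixed_log_probs[ind]

        # Critic: a few Adam steps on the squared error, with gradient clipping.
        for _ in range(value_iters):
            optimizer_value.zero_grad()
            value_loss = (value_net(s) - ret).pow(2).mean()
            value_loss.backward()
            nn.utils.clip_grad_norm_(value_net.parameters(), 0.5)
            optimizer_value.step()

        # Actor: clipped surrogate objective on the same mini-batch.
        dist = Normal(policy_net(s), log_std.exp())
        log_probs = dist.log_prob(a).sum(-1, keepdim=True)
        ratio = torch.exp(log_probs - old_lp)
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * adv
        policy_surr = -torch.min(surr1, surr2).mean()

        optimizer_policy.zero_grad()
        policy_surr.backward()
        nn.utils.clip_grad_norm_(policy_net.parameters(), 0.5)
        optimizer_policy.step()


# Example call on random data of the assumed shapes.
S, A = torch.randn(256, 3), torch.randn(256, 1)
ADV, RET = torch.randn(256, 1), torch.randn(256, 1)
ppo_update(S, A, ADV, RET)
```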
2 changes: 1 addition & 1 deletion core/algorithm/trpo.py
@@ -56,7 +56,7 @@ def closure():

# weight decay
for param in value_net.parameters():
value_loss += param.pow(2).sum() * l2_reg
value_loss = value_loss + param.pow(2).sum() * l2_reg
value_loss.backward()
return value_loss

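The one-line TRPO change above rewrites the weight-decay accumulation out-of-place (`value_loss = value_loss + ...` instead of `+=`). In-place arithmetic on tensors that autograd is tracking can trip PyTorch's version checks when the modified value is needed for backward, so the out-of-place form is the safer idiom inside an LBFGS closure. A minimal sketch of that closure, with a toy `value_net` and random data standing in for the repository's objects:

```python
import torch
import torch.nn as nn
from torch.optim import LBFGS

value_net = nn.Linear(4, 1)
states, returns = torch.randn(64, 4), torch.randn(64, 1)
l2_reg = 1e-3

optimizer = LBFGS(value_net.parameters(), max_iter=25, history_size=5)

def closure():
    optimizer.zero_grad()
    value_loss = (value_net(states) - returns).pow(2).mean()
    # Out-of-place accumulation: each `+` builds a new graph node instead of
    # mutating `value_loss`, so autograd's bookkeeping is never disturbed.
    for param in value_net.parameters():
        value_loss = value_loss + param.pow(2).sum() * l2_reg
    value_loss.backward()
    return value_loss

optimizer.step(closure)
```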
3 changes: 0 additions & 3 deletions core/environment/environment.py
@@ -1,9 +1,6 @@
from core.common.types import StepDict, InfoDict, StringList


__all__ = []


class Environment(object):

def __init__(self, task_name: str):
14 changes: 5 additions & 9 deletions core/filter/filter.py
@@ -15,7 +15,6 @@ class Filter(object):

def __init__(self):
self.device = torch.device("cpu")
pass

def init(self):
pass
@@ -30,25 +29,22 @@ def operate_currentStep(self, current_step: StepDict) -> StepDict:
"""
decorate current stepDict before transferring to policy net
"""
current_step['s'] = torch.as_tensor(current_step['s'], dtype=torch.float32, device=self.device)
return current_step

def operate_recordStep(self, last_step: StepDict) -> StepDict:
"""
decorate last stepDict before putting to step memory
"""
last_step['a'] = torch.as_tensor(last_step['a'], dtype=torch.float32, device=self.device)
last_step['r'] = torch.as_tensor([last_step['r']], dtype=torch.float32, device=self.device)
return last_step

def operate_stepList(self, step_list: StepDictList, done: bool) -> SampleTraj:
"""
decorate step memory of one roll-out epoch and form single trajectory dict,
must contain keyword "trajectory", "done", "length", "reward sum"
"""
states = torch.stack([step['s'] for step in step_list], dim=0)
actions = torch.stack([step['a'] for step in step_list], dim=0)
rewards = torch.stack([step['r'] for step in step_list], dim=0)
states = torch.as_tensor([step['s'] for step in step_list], dtype=torch.float32, device=self.device)
actions = torch.as_tensor([step['a'] for step in step_list], dtype=torch.float32, device=self.device)
rewards = torch.as_tensor([[step['r']] for step in step_list], dtype=torch.float32, device=self.device)
return {"states": states,
"actions": actions,
"rewards": rewards,
@@ -68,8 +64,8 @@ def operate_trajectoryList(self, traj_list: List[SampleTraj]) -> (SampleBatch, I
actions = torch.cat(actions, dim=0)
rewards = torch.cat(rewards, dim=0)

steps = torch.as_tensor([b["step"] for b in traj_list], dtype=torch.int)
rsums = torch.as_tensor([b["rsum"] for b in traj_list], dtype=torch.float32)
steps = torch.as_tensor([b["step"] for b in traj_list], dtype=torch.int, device=self.device)
rsums = torch.as_tensor([b["rsum"] for b in traj_list], dtype=torch.float32, device=self.device)

batch = {"states": states,
"actions": actions,
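With the filter changes above, per-step tensors are no longer created in `operate_currentStep`/`operate_recordStep`; instead a whole roll-out is converted to batched tensors once in `operate_stepList`. A small sketch of that batching follows, using illustrative step data and only the key names visible in the diff (`'s'`, `'a'`, `'r'`, and the trajectory keys `"states"`/`"actions"`/`"rewards"`); the remaining trajectory fields are assumptions.

```python
import torch

device = torch.device("cpu")

# A roll-out of raw step dicts, as the sampler might hand them over (illustrative data).
step_list = [
    {'s': [0.1, 0.2, 0.3], 'a': [0.5], 'r': 1.0},
    {'s': [0.2, 0.1, 0.0], 'a': [-0.3], 'r': 0.5},
]
done = True

# One tensor conversion per trajectory instead of one per step:
states = torch.as_tensor([step['s'] for step in step_list], dtype=torch.float32, device=device)     # [T, obs_dim]
actions = torch.as_tensor([step['a'] for step in step_list], dtype=torch.float32, device=device)    # [T, act_dim]
rewards = torch.as_tensor([[step['r']] for step in step_list], dtype=torch.float32, device=device)  # [T, 1]

traj = {"states": states, "actions": actions, "rewards": rewards,
        "done": done, "step": len(step_list), "rsum": rewards.sum().item()}
```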
43 changes: 26 additions & 17 deletions core/filter/zfilter.py
@@ -1,4 +1,5 @@
import torch
import numpy as np
from core.filter.filter import Filter
from core.math.advantage import advantage
from core.common import StepDictList, SampleTraj, SampleBatch, StepDict, List, ParamDict, InfoDict
@@ -24,19 +25,25 @@ def __init__(self, advantage_gamma, advantage_tau, clip=10.):
self.tau = advantage_tau

def init(self):
super(ZFilter, self).init()
self.mean = None
self.errsum = None
self.n_step = 0
self.is_fixed = False

def finalize(self):
super(ZFilter, self).finalize()
self.mean = None
self.errsum = None
self.n_step = 0

def reset(self, param: ParamDict):
super(ZFilter, self).reset(param)
self.mean, self.errsum, self.n_step, self.is_fixed =\
param.require("zfilter mean", "zfilter errsum", "zfilter n_step", "fixed filter")
if self.mean is not None:
self.mean = self.mean.cpu().numpy()
self.errsum = self.errsum.cpu().numpy()

def operate_currentStep(self, current_step: StepDict) -> StepDict:
"""
@@ -47,26 +54,30 @@ def operate_currentStep(self, current_step: StepDict) -> StepDict:
current_step = super(ZFilter, self).operate_currentStep(current_step)
x = current_step['s']

if self.mean is None:
self.mean = x.copy()
self.errsum = np.zeros_like(self.mean)

if not self.is_fixed:
if self.mean is None:
self.mean = x.clone()
self.errsum = torch.zeros_like(self.mean)
else:
self.n_step += 1
oldM = self.mean.clone()
self.mean = self.mean + (x - self.mean) / self.n_step
self.errsum = self.errsum + (x - oldM) * (x - self.mean)
self.n_step += 1
oldM = self.mean
self.mean = self.mean + (x - self.mean) / self.n_step
self.errsum = self.errsum + (x - oldM) * (x - self.mean)

std = torch.sqrt(self.errsum / (self.n_step - 1)) if self.n_step > 1 else self.mean
std = np.sqrt(self.errsum / (self.n_step - 1)) if self.n_step > 1 else self.mean

x -= self.mean
x /= std + 1e-8
if self.clip is not None:
x.clamp_(-self.clip, self.clip)
x = x.clip(-self.clip, self.clip)

current_step['s'] = x
return current_step

def operate_recordStep(self, last_step: StepDict) -> StepDict:
last_step = super(ZFilter, self).operate_recordStep(last_step)
return last_step

def operate_stepList(self, step_list: StepDictList, done: bool) -> SampleTraj:
traj = super(ZFilter, self).operate_stepList(step_list, done)

@@ -76,8 +87,8 @@ def operate_stepList(self, step_list: StepDictList, done: bool) -> SampleTraj:
traj["advantages"] = advantages
traj["returns"] = returns

traj["filter state dict"] = {"zfilter mean": self.mean.clone(),
"zfilter errsum": self.errsum.clone(),
traj["filter state dict"] = {"zfilter mean": self.mean.copy(),
"zfilter errsum": self.errsum.copy(),
"zfilter n_step": self.n_step,
"fixed filter": self.is_fixed}
return traj
@@ -107,12 +118,10 @@ def operate_trajectoryList(self, traj_list: List[SampleTraj]) -> (SampleBatch, I

def getStateDict(self) -> ParamDict:
state_dict = super(ZFilter, self).getStateDict()
return state_dict + ParamDict({"zfilter mean": self.mean,
"zfilter errsum": self.errsum,
return state_dict + ParamDict({"zfilter mean": torch.as_tensor(self.mean, dtype=torch.float32, device=torch.device("cpu")) if self.mean is not None else None,
"zfilter errsum": torch.as_tensor(self.errsum, dtype=torch.float32, device=torch.device("cpu")) if self.errsum is not None else None,
"zfilter n_step": self.n_step,
"fixed filter": self.is_fixed})

def to_device(self, device: torch.device):
if self.mean is not None:
self.mean.to(device)
self.errsum.to(device)
super(ZFilter, self).to_device(device)
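After this change, `ZFilter` keeps its running statistics as NumPy arrays, updates them with Welford's online mean/variance, normalizes and clips each observation, and only converts back to CPU tensors when `getStateDict()` exports them. The class below isolates that normalization logic as a sketch: `ZFilterSketch` is a made-up name, the `Filter`/`ParamDict` plumbing is omitted, and the fewer-than-two-samples fallback is handled slightly differently than in the repository's code.

```python
import numpy as np

class ZFilterSketch:
    """Running observation normalizer using Welford's online mean/variance."""

    def __init__(self, clip=10.0):
        self.clip = clip
        self.mean = None       # running mean of observations
        self.errsum = None     # running sum of squared deviations
        self.n_step = 0
        self.is_fixed = False  # when True, statistics are frozen

    def __call__(self, x) -> np.ndarray:
        x = np.asarray(x, dtype=np.float32)
        if self.mean is None:
            self.mean = x.copy()
            self.errsum = np.zeros_like(x)
        if not self.is_fixed:
            # Welford's update: n_step counts observations seen so far.
            self.n_step += 1
            old_mean = self.mean
            self.mean = old_mean + (x - old_mean) / self.n_step
            self.errsum = self.errsum + (x - old_mean) * (x - self.mean)
        # Fallback for fewer than two samples differs from the repo's code.
        std = np.sqrt(self.errsum / (self.n_step - 1)) if self.n_step > 1 else np.ones_like(x)
        x = (x - self.mean) / (std + 1e-8)
        if self.clip is not None:
            x = x.clip(-self.clip, self.clip)
        return x

# Usage: normalize a stream of observations.
zf = ZFilterSketch()
for obs in np.random.randn(5, 3).astype(np.float32):
    print(zf(obs))
```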
Empty file added core/model/nets/__init__.py
Empty file.