Skip to content

Commit

Permalink
Sim Hash working
Browse files Browse the repository at this point in the history
  • Loading branch information
BoogaQ committed Aug 18, 2020
1 parent 0365d77 commit 6e440d2
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 12 deletions.
6 changes: 6 additions & 0 deletions .ipynb_checkpoints/Untitled-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}
234 changes: 234 additions & 0 deletions Untitled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
"C:\\Users\\ktiju\\anaconda3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
]
}
],
"source": [
"from hashlib import md5\n",
"import gym\n",
"import mujoco_py\n",
"from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env"
]
},
{
"cell_type": "code",
"execution_count": 271,
"metadata": {},
"outputs": [],
"source": [
"env = make_vec_env(\"Swimmer-v2\", 4)"
]
},
{
"cell_type": "code",
"execution_count": 272,
"metadata": {},
"outputs": [],
"source": [
"obs = np.array(env.reset())"
]
},
{
"cell_type": "code",
"execution_count": 435,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([4., 4., 0., 0., 4., 0., 0., 0., 0., 1., 3., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])"
]
},
"execution_count": 435,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.tree"
]
},
{
"cell_type": "code",
"execution_count": 296,
"metadata": {},
"outputs": [],
"source": [
"import numpy\n",
"# SumTree\n",
"# a binary tree data structure where the parent’s value is the sum of its children\n",
"class SumTree:\n",
" write = 0\n",
" def __init__(self, capacity):\n",
" self.capacity = capacity\n",
" self.tree = numpy.zeros(2 * capacity - 1)\n",
" self.data = numpy.zeros(capacity, dtype=object)\n",
" self.n_entries = 0\n",
" self.pending_idx = set()\n",
"\n",
" # update to the root node\n",
" def _propagate(self, idx, change):\n",
" parent = (idx - 1) // 2\n",
" self.tree[parent] += change\n",
" if parent != 0:\n",
" self._propagate(parent, change)\n",
"\n",
" # find sample on leaf node\n",
" def _retrieve(self, idx, s):\n",
" left = 2 * idx + 1\n",
" right = left + 1\n",
"\n",
" if left >= len(self.tree):\n",
" return idx\n",
"\n",
" if s <= self.tree[left]:\n",
" return self._retrieve(left, s)\n",
" else:\n",
" return self._retrieve(right, s - self.tree[left])\n",
"\n",
" def total(self):\n",
" return self.tree[0]\n",
"\n",
" # store priority and sample\n",
" def add(self, p, data):\n",
" idx = self.write + self.capacity - 1\n",
" self.pending_idx.add(idx)\n",
"\n",
" self.data[self.write] = data\n",
" self.update(idx, p)\n",
"\n",
" self.write += 1\n",
" if self.write >= self.capacity:\n",
" self.write = 0\n",
"\n",
" if self.n_entries < self.capacity:\n",
" self.n_entries += 1\n",
"\n",
" # update priority\n",
" def update(self, idx, p):\n",
" if idx not in self.pending_idx:\n",
" return\n",
" self.pending_idx.remove(idx)\n",
" change = p - self.tree[idx]\n",
" self.tree[idx] = p\n",
" self._propagate(idx, change)\n",
"\n",
" # get priority and sample\n",
" def get(self, s):\n",
" idx = self._retrieve(0, s)\n",
" dataIdx = idx - self.capacity + 1\n",
" self.pending_idx.add(idx)\n",
" return (idx, self.tree[idx], dataIdx)"
]
},
{
"cell_type": "code",
"execution_count": 408,
"metadata": {},
"outputs": [],
"source": [
"tree = SumTree(10)"
]
},
{
"cell_type": "code",
"execution_count": 412,
"metadata": {},
"outputs": [],
"source": [
"tree.add(3, (\"s2\", \"a2\", \"r2\"))"
]
},
{
"cell_type": "code",
"execution_count": 413,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([('s1', 'a1', 'r1'), ('s2', 'a2', 'r2'), 0, 0, 0, 0, 0, 0, 0, 0],\n",
" dtype=object)"
]
},
"execution_count": 413,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.data"
]
},
{
"cell_type": "code",
"execution_count": 434,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(9, 1.0, 0)"
]
},
"execution_count": 434,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.get(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Binary file modified __pycache__/buffer.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/ppo.cpython-37.pyc
Binary file not shown.
25 changes: 23 additions & 2 deletions buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logger
from util import RunningMeanStd

from collections import namedtuple
from collections import namedtuple, defaultdict



Expand Down Expand Up @@ -123,7 +123,7 @@ def _normalize_reward(reward,
return reward

class RolloutStorage(BaseBuffer):
def __init__(self, buffer_size, n_envs, obs_space, action_space, gae_lam = 0.95, gamma = 0.99):
def __init__(self, buffer_size, n_envs, obs_space, action_space, gae_lam = 0.95, gamma = 0.99, sim_hash = False):
super(RolloutStorage, self).__init__(buffer_size, obs_space, action_space, n_envs = n_envs)

self.gae_lam = gae_lam
Expand All @@ -137,6 +137,15 @@ def __init__(self, buffer_size, n_envs, obs_space, action_space, gae_lam = 0.95,
self.full = False
self.reset()

self.count_table = defaultdict(lambda:0)
self.A = np.random.randn(16, self.obs_shape[0])

if sim_hash:
self.do_hash = True
self.beta = 0.1
else:
self.do_hash = False

self.RolloutSample = namedtuple('RolloutSample', ['observations', 'actions', 'old_values', 'old_log_probs', 'advantages', 'returns'])

def reset(self):
Expand All @@ -163,6 +172,8 @@ def add(self, obs, action, reward, value, mask, log_prob):
"""

self.observations[self.pos] = np.array(obs).copy()
if self.do_hash:
reward = self.sim_hash(obs, reward)
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.masks[self.pos] = np.array(mask).copy()
Expand All @@ -173,6 +184,16 @@ def add(self, obs, action, reward, value, mask, log_prob):
if self.pos == self.buffer_size:
self.full = True

def sim_hash(self, obs, rewards):
    """Return ``rewards`` plus a SimHash count-based exploration bonus.

    Each observation is projected with the random matrix ``self.A`` and the
    sign pattern of the projection is used as a discrete hash key. The visit
    count of that key in ``self.count_table`` is incremented, and
    ``self.beta / sqrt(count)`` is added to the matching reward.

    Args:
        obs: Observations, one row per environment. A single 1-D
            observation is treated as a batch of one.
        rewards: Rewards matching ``obs`` (one entry per observation).

    Returns:
        A new float array shaped like ``rewards`` with the bonus added.
        The ``rewards`` argument itself is not modified.
    """
    # A lone 1-D observation must still yield ONE hash key, not one key per
    # projected bit, so promote it to a batch of size one.
    obs = np.atleast_2d(np.asarray(obs))

    # One bit per row of self.A: is the random projection positive?
    bits = np.greater(np.dot(self.A, obs.T).T, 0)

    rewards = np.asarray(rewards)
    # Accumulate the bonus in a separate float array: writing into an
    # integer reward array would truncate the fractional bonus to zero, and
    # writing into the caller's array would leak the bonus outside the buffer.
    bonus = np.empty(len(bits), dtype=float)
    for index, row in enumerate(bits):
        # Build the key directly from the bits so it does not depend on
        # NumPy's print options (line width, threshold).
        key = ''.join('1' if bit else '0' for bit in row)
        self.count_table[key] += 1
        bonus[index] = self.beta / np.sqrt(self.count_table[key])

    return rewards + bonus.reshape(rewards.shape)


def compute_returns_and_advantages(self, last_value, dones):
"""
Expand Down
12 changes: 9 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
if __name__ == "__main__":

for i in range(10):
model = EvolutionStrategy("Reacher-v2", [64,64], nsr_plateu = 1.5, nsr_range = [0,5], nsr_update = 0.05, sigma = 0.1, learning_rate = 0.03, decay = 0.995,
novelty_param = 1, num_threads = 4)
model.run(total_timesteps = 1e+3, log_interval = 1, reward_target = -10, log_to_file = True)
model = PPO(env_id = "Acrobot-v1", hidden_size = 64, lr = 0.001, gamma = 0.99, gae_lam = 0.95, max_grad_norm = 5,
nstep = 1024, batch_size = 64, n_epochs = 10, clip_range = 0.2, ent_coef = 0.01)
model.learn(total_timesteps = 1e+6, log_interval = 1, log_to_file = False)


# ICM reacher hidden_size = 64, lr = 0.001, int_hidden_size = 16, int_lr = 0.0001, int_vf_coef = 0.05, nstep = 256
Expand All @@ -47,3 +47,9 @@

# PPO_RND(env_id = "InvertedDoublePendulum-v2", hidden_size = 64, lr = 0.00025, gamma = 0.99, gae_lam = 0.95, max_grad_norm = 5, rnd_start = 1e+3,
# nstep = 2048, batch_size = 128, n_epochs = 10, clip_range = 0.2, ent_coef = 0.00, int_vf_coef = 0.5, int_hidden_size = 32, int_lr = 0.001)

"""
model = EvolutionStrategy("Reacher-v2", [64,64], nsr_plateu = 1.5, nsr_range = [0,5], nsr_update = 0.05, sigma = 0.1, learning_rate = 0.03, decay = 0.995,
novelty_param = 1, num_threads = 4)
model.run(total_timesteps = 1e+3, log_interval = 1, reward_target = -10, log_to_file = True)
"""
13 changes: 6 additions & 7 deletions ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
from abc import ABC, abstractmethod

from collections import deque
from collections import deque, defaultdict
import multiprocessing

import logger
Expand Down Expand Up @@ -46,10 +46,8 @@ def __init__(self,
max_grad_norm):

self.env_id = env_id
try:
self.env = make_env(env_id, n_envs = 4)
except:
self.env = make_env(env_id, n_envs = 1, vec_env_cls = DummyVecEnv)

self.env = make_env(env_id, n_envs = 4)

self.num_envs = self.env.num_envs if isinstance(self.env, VecEnv) else 1
self.state_dim = self.env.observation_space.shape[0]
Expand Down Expand Up @@ -120,11 +118,12 @@ def __init__(self, *,
ent_coef = .01,
vf_coef = 1,
max_grad_norm = 0.2,
hidden_size = 128):
hidden_size = 128,
sim_hash = True):
super(PPO, self).__init__(env_id, lr, nstep, batch_size, n_epochs, gamma, gae_lam, clip_range, ent_coef, vf_coef, max_grad_norm)

self.policy = Policy(self.env, hidden_size)
self.rollout = RolloutStorage(nstep, self.num_envs, self.env.observation_space, self.env.action_space, gae_lam = gae_lam, gamma = gamma)
self.rollout = RolloutStorage(nstep, self.num_envs, self.env.observation_space, self.env.action_space, gae_lam = gae_lam, gamma = gamma, sim_hash = sim_hash)
self.optimizer = optim.Adam(self.policy.net.parameters(), lr = lr, eps = 1e-5)

self.last_obs = self.env.reset()
Expand Down

0 comments on commit 6e440d2

Please sign in to comment.