
Commit

added evaluator
StepNeverStop committed Sep 22, 2020
1 parent c21a450 commit a47bc0c
Showing 7 changed files with 170 additions and 44 deletions.
32 changes: 20 additions & 12 deletions rls/common/trainer.py
@@ -387,25 +387,23 @@ def apex(self) -> NoReturn:
if self.train_args['apex'] == 'learner':
from rls.distribute.apex.learner import learner
learner(
ip=self.train_args['apex_learner_ip'],
port=self.train_args['apex_learner_port'],
env=self.env,
model=self.model,
env=self.env
ip=self.train_args['apex_learner_ip'],
port=self.train_args['apex_learner_port']
)
return

if self.train_args['apex'] == 'worker':
elif self.train_args['apex'] == 'worker':
from rls.distribute.apex.worker import worker
worker(
env=self.env,
model=self.model,
learner_ip=self.train_args['apex_learner_ip'],
learner_port=self.train_args['apex_learner_port'],
buffer_ip=self.train_args['apex_buffer_ip'],
buffer_port=self.train_args['apex_buffer_port'],
model=self.model,
env=self.env)
return

if self.train_args['apex'] == 'buffer':
worker_args=self.train_args['apex_worker_args']
)
elif self.train_args['apex'] == 'buffer':
from rls.distribute.apex.buffer import buffer
buffer(
ip=self.train_args['apex_buffer_ip'],
@@ -414,4 +412,14 @@ def apex(self) -> NoReturn:
learner_port=self.train_args['apex_learner_port'],
buffer_args=self.train_args['apex_buffer_args']
)
return
elif self.train_args['apex'] == 'evaluator':
from rls.distribute.apex.evaluator import evaluator
evaluator(
env=self.env,
model=self.model,
learner_ip=self.train_args['apex_learner_ip'],
learner_port=self.train_args['apex_learner_port'],
evaluator_args=self.train_args['apex_evaluator_args']
)

return
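For reference, these are the train_args keys the new 'evaluator' branch reads, gathered into one sketch; the concrete values below are illustrative assumptions, not the project's defaults (the evaluator_args defaults appear in rls/distribute/apex/config.yaml further down):

train_args = {
    'apex': 'evaluator',
    'apex_learner_ip': '127.0.0.1',     # assumed address of the learner gRPC server
    'apex_learner_port': '50051',       # assumed port; kept as a string, it is joined with ':' later
    'apex_evaluator_args': {            # forwarded verbatim to EvalProc (see evaluator.py below)
        'pull_interval': 2,             # pull parameters roughly every other evaluation episode
        'episode_sleep': 5,             # seconds to sleep between evaluation episodes
    },
}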
26 changes: 15 additions & 11 deletions rls/distribute/apex/buffer.py
@@ -16,6 +16,7 @@
proto2exps_and_tderror
from rls.distribute.utils.check import check_port_in_use
from rls.memories.replay_buffer import PrioritizedExperienceReplay
from rls.utils.display import colorize
from rls.utils.logging_utils import get_logger
logger = get_logger(__name__)

@@ -30,19 +31,17 @@ def __init__(self, learner_ip, learner_port, buffer, lock):
self.lock = lock

def run(self):
train_time = 0
while True:
if self.buffer.is_lg_batch_size:
with self.lock:
exps = self.buffer.sample()
exps, idxs = self.buffer.sample(return_index=True)
prios = self.buffer.get_IS_w().reshape(-1, 1)
td_error = self.learner_stub.SendExperienceGetPriorities(
exps_and_prios2proto(
exps=exps,
prios=prios))
td_error = proto2numpy(td_error)
self.buffer.update(td_error, train_time)
train_time += 1
self.buffer.update(td_error, idxs)
self.learner_channel.close()


@@ -53,18 +52,23 @@ def __init__(self, buffer, lock):
self.lock = lock

def SendTrajectories(self, request_iterator: Iterator[apex_datatype_pb2.ListNDarray], context) -> apex_datatype_pb2.Nothing:
'''
The worker sends a batch of trajectories to the buffer.
'''
for traj in request_iterator:
self.buffer.add(*batch_proto2numpy(traj))
logger.info('receive Trajectories from worker.')
return apex_datatype_pb2.Nothing()

def SendExperiences(self, request_iterator: Iterator[apex_datatype_pb2.ExpsAndTDerror], context) -> apex_datatype_pb2.Nothing:
self.lock.acquire()
for request in request_iterator:
data, td_error = proto2exps_and_tderror(request)
self.buffer.apex_add_batch(td_error, *data)
logger.info('receive Experiences from worker.')
self.lock.release()
'''
The worker sends a batch of experiences to the buffer.
'''
with self.lock:
for request in request_iterator:
data, td_error = proto2exps_and_tderror(request)
self.buffer.apex_add_batch(td_error, *data)
logger.info('receive Experiences from worker.')
return apex_datatype_pb2.Nothing()


@@ -84,7 +88,7 @@ def buffer(
apex_buffer_pb2_grpc.add_BufferServicer_to_server(BufferServicer(buffer=buffer, lock=threadLock), server)
server.add_insecure_port(':'.join([ip, port]))
server.start()
logger.info('start buffer success.')
logger.info(colorize('start buffer success.', color='green'))

learn_thread = LearnThread(learner_ip, learner_port, buffer, threadLock)
learn_thread.start()
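The LearnThread change above stops updating priorities at a running train_time counter and instead updates them at the indices returned by sample(return_index=True). A minimal standalone sketch of that contract (hypothetical TinyPER class, not the project's PrioritizedExperienceReplay API):

import numpy as np

class TinyPER:
    def __init__(self, capacity):
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.size = 0

    def add(self, priority):
        self.priorities[self.size] = priority
        self.size += 1

    def sample(self, batch_size):
        # sample proportionally to priority and remember which slots were drawn
        p = self.priorities[:self.size]
        idxs = np.random.choice(self.size, size=batch_size, p=p / p.sum())
        return idxs  # the real buffer also returns the stored transitions

    def update(self, td_error, idxs):
        # priorities must be rewritten at exactly the sampled slots
        self.priorities[idxs] = np.abs(td_error) + 1e-6

Updating by a counter rather than by the sampled indices would silently attach the new priorities to the wrong transitions, which is what the change above prevents.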
9 changes: 8 additions & 1 deletion rls/distribute/apex/config.yaml
@@ -12,4 +12,11 @@ apex_buffer_args:
epsilon: 0.01
global_v: false

# each worker uses a different degree of exploration
apex_worker_args:
rollout_interval: 1 # seconds
is_send_traj: False

apex_evaluator_args:
pull_interval: 2 # episode
episode_sleep: 5 # seconds
# each worker uses a different degree of exploration
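Both WorkerCls and EvalProc consume their args block by copying every key onto the instance with setattr, so adding a key to apex_worker_args or apex_evaluator_args makes it available as an attribute. A small sketch of that flow, assuming PyYAML and the config path shown above:

import yaml

class ArgsConsumer:                       # mirrors the setattr loop used by WorkerCls / EvalProc
    def __init__(self, args):
        for k, v in args.items():
            setattr(self, k, v)

with open('rls/distribute/apex/config.yaml') as f:
    cfg = yaml.safe_load(f)

worker_cfg = ArgsConsumer(cfg['apex_worker_args'])
print(worker_cfg.rollout_interval, worker_cfg.is_send_traj)   # -> 1 False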
91 changes: 91 additions & 0 deletions rls/distribute/apex/evaluator.py
@@ -0,0 +1,91 @@
import grpc
import time
import numpy as np

from rls.utils.np_utils import \
SMA, \
arrprint
from rls.distribute.pb2 import \
apex_datatype_pb2, \
apex_learner_pb2_grpc
from rls.distribute.utils.apex_utils import \
batch_proto2numpy
from rls.utils.logging_utils import get_logger
logger = get_logger(__name__)


class EvalProc(object):
'''
Evaluate the performance of the policy.
'''

def __init__(self, env, model, evaluator_args, callback_func):
super().__init__()
self.env = env
self.model = model
self.callback_func = callback_func

for k, v in evaluator_args.items():
setattr(self, k, v)

def run(self):
n = self.env.n
i = 1 if self.env.obs_type == 'visual' else 0
state = [np.full((n, 0), []), np.full((n, 0), [])]
sma = SMA(100)
total_step = 0
episode = 0

while True:
if episode % self.pull_interval:
self.model.set_worker_params(self.callback_func())
logger.info('pulled parameters from learner successfully.')
episode += 1
self.model.reset()
state[i] = self.env.reset()
dones_flag = np.zeros(self.env.n)
step = 0
rets = np.zeros(self.env.n)
last_done_step = -1
while True:
step += 1
# env.render(record=False)
action = self.model.choose_action(s=state[0], visual_s=state[1])
_, reward, done, info, state[i] = self.env.step(action)
rets += (1 - dones_flag) * reward
dones_flag = np.sign(dones_flag + done)
self.model.partial_reset(done)
total_step += 1
if all(dones_flag):
if last_done_step == -1:
last_done_step = step
break

if step >= 200:
break

sma.update(rets)
self.model.writer_summary(
episode,
reward_mean=rets.mean(),
reward_min=rets.min(),
reward_max=rets.max(),
step=last_done_step,
**sma.rs
)
logger.info(f'Eps: {episode:3d} | S: {step:4d} | LDS {last_done_step:4d} | R: {arrprint(rets, 2)}')
time.sleep(self.episode_sleep)


def evaluator(env,
model,
learner_ip,
learner_port,
evaluator_args):
learner_channel = grpc.insecure_channel(':'.join([learner_ip, learner_port]))
learner_stub = apex_learner_pb2_grpc.LearnerStub(learner_channel)

evalproc = EvalProc(env, model, evaluator_args, callback_func=lambda: batch_proto2numpy(learner_stub.GetParams(apex_datatype_pb2.Nothing())))
evalproc.run()

learner_channel.close()
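A standalone note on the pull cadence implied by the `if episode % self.pull_interval:` test above: the modulo is truthy on non-multiples of pull_interval, so with the default pull_interval of 2 from config.yaml the evaluator refreshes parameters on odd episode counters (1, 3, 5, ...), roughly every other evaluation episode:

pull_interval = 2                                   # default from config.yaml
pulls = [ep for ep in range(10) if ep % pull_interval]
print(pulls)                                        # [1, 3, 5, 7, 9]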
18 changes: 9 additions & 9 deletions rls/distribute/apex/learner.py
@@ -3,12 +3,11 @@
import threading
import numpy as np

from concurrent import futures

from rls.utils.np_utils import \
SMA, \
arrprint

from concurrent import futures

from rls.distribute.pb2 import \
apex_datatype_pb2, \
apex_learner_pb2_grpc
@@ -20,6 +19,7 @@
proto2exps_and_prios
from rls.distribute.utils.check import check_port_in_use
from rls.common.collector import GymCollector
from rls.utils.display import colorize
from rls.utils.logging_utils import get_logger
logger = get_logger(__name__)

@@ -73,7 +73,7 @@ def run(self):
step=last_done_step,
**sma.rs
)
print(f'Eps: {episode:3d} | S: {step:4d} | LDS {last_done_step:4d} | R: {arrprint(rets, 2)}')
logger.info(f'Eps: {episode:3d} | S: {step:4d} | LDS {last_done_step:4d} | R: {arrprint(rets, 2)}')
time.sleep(5)


@@ -104,22 +104,22 @@ def SendExperienceGetPriorities(self, request: apex_datatype_pb2.ExpsAndPrios, c
self.train_step += 1
if self.train_step % 100 == 0:
self.model.save_checkpoint(train_step=self.train_step)
# logger.info('send new priorities to buffer.')
logger.info('send new priorities to buffer...')
return td_error


def learner(ip, port, model, env):
def learner(env, model, ip, port):
check_port_in_use(port, ip, try_times=10, server_name='learner')
assert hasattr(model, 'apex_learn') and hasattr(model, 'apex_cal_td'), 'this algorithm does not support Ape-X learning for now.'

server = grpc.server(futures.ThreadPoolExecutor())
apex_learner_pb2_grpc.add_LearnerServicer_to_server(LearnerServicer(model), server)
server.add_insecure_port(':'.join([ip, port]))
server.start()
logger.info('start learner success.')
logger.info(colorize('start learner success.', color='green'))

eval_thread = EvalThread(env, model)
eval_thread.start()
# eval_thread = EvalThread(env, model)
# eval_thread.start()

# GymCollector.evaluate(env, model)
server.wait_for_termination()
36 changes: 26 additions & 10 deletions rls/distribute/apex/worker.py
@@ -16,12 +16,33 @@
logger = get_logger(__name__)


def worker(learner_ip,
class WorkerCls(object):

    def __init__(self, env, model, buffer_stub, worker_args, callback_func):
        self.env = env
        self.model = model
        self.buffer_stub = buffer_stub
        self.callback_func = callback_func
        for k, v in worker_args.items():
            setattr(self, k, v)

    def run(self):
        while True:
            # pull the latest parameters from the learner before each rollout round
            self.model.set_worker_params(self.callback_func())
            if self.is_send_traj:
                self.buffer_stub.SendTrajectories(GymCollector.run_trajectory(self.env, self.model))
            else:
                for _ in range(10):
                    self.buffer_stub.SendExperiences(GymCollector.run_exps_stream(self.env, self.model))
            time.sleep(self.rollout_interval)


def worker(env,
model,
learner_ip,
learner_port,
buffer_ip,
buffer_port,
model,
env):
worker_args):
learner_channel = grpc.insecure_channel(':'.join([learner_ip, learner_port]))
buffer_channel = grpc.insecure_channel(':'.join([buffer_ip, buffer_port]))

@@ -34,13 +55,8 @@ def worker(learner_ip,
# arr_list = [np.arange(4).reshape(2, 2), np.arange(3).astype(np.int32), np.array([])]
# learner_stub.SendBatchNumpyArray(batch_numpy2proto(arr_list))

while True:
model.set_worker_params(
batch_proto2numpy(learner_stub.GetParams(apex_datatype_pb2.Nothing())))
for _ in range(10):
buffer_stub.SendExperiences(GymCollector.run_exps_stream(env, model))
time.sleep(0.5)
# buffer_stub.SendTrajectories(GymCollector.run_trajectory(env, model))
workercls = WorkerCls(env, model, buffer_stub, worker_args, callback_func=lambda: batch_proto2numpy(learner_stub.GetParams(apex_datatype_pb2.Nothing())))
workercls.run()

learner_channel.close()
buffer_channel.close()
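A standalone illustration of the send loop the new worker_args control (the senders here are stand-ins, not the project's gRPC stubs). With the config.yaml defaults (is_send_traj: False, rollout_interval: 1) each round still pushes ten experience streams, but the pause between rounds is now configurable instead of the old hard-coded 0.5 seconds:

import time

def worker_round(is_send_traj, rollout_interval, send_trajectory, send_experiences):
    if is_send_traj:
        send_trajectory()               # one whole trajectory per round
    else:
        for _ in range(10):
            send_experiences()          # ten experience streams per round
    time.sleep(rollout_interval)

worker_round(False, 1, send_trajectory=lambda: None, send_experiences=lambda: None)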
2 changes: 1 addition & 1 deletion run.py
@@ -20,7 +20,7 @@
-t,--train-step=<n> specify the total number of training steps used to optimize the policy model [default: None]
-u,--unity whether to use the unity client, i.e. whether to train with the UNITY3D editor [default: False]
--apex=<str> i.e. "learner"/"worker"/"buffer" [default: None]
--apex=<str> i.e. "learner"/"worker"/"buffer"/"evaluator" [default: None]
--unity-env=<name> specify the name of the UNITY3D training environment [default: None]
--config-file=<file> specify the path of the model hyperparameter/training configuration file [default: None]
--store-dir=<file> specify the directory in which to store the model, logs and other data [default: None]
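Putting it together, the four Ape-X roles that run.py's --apex option now accepts would each be started as a separate process. A hedged launch sketch: only --apex is taken from the usage text above; any other required options (environment selection, config file, store directory) are omitted and would have to be supplied per that usage:

import subprocess

roles = ('learner', 'buffer', 'worker', 'evaluator')
procs = [subprocess.Popen(['python', 'run.py', f'--apex={role}']) for role in roles]
for p in procs:
    p.wait()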
