forked from MarcoMeter/recurrent-ppo-truncated-bptt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: worker.py
62 lines (53 loc) · 1.98 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import multiprocessing
import multiprocessing.connection
from utils import create_env
def worker_process(remote: multiprocessing.connection.Connection, config: dict) -> None:
    """Run one environment in this process and serve commands sent over a pipe.

    Each message received on ``remote`` is a ``(cmd, data)`` tuple; the reply
    sent back is the environment's return value:

    * ``"step"``  -- sends ``env.step(data)``
    * ``"reset"`` -- sends ``env.reset()``
    * ``"close"`` -- sends ``env.close()``, closes ``remote`` and returns

    Arguments:
        remote {multiprocessing.connection.Connection} -- This process's end of
            the pipe shared with the main process
        config {dict} -- Configuration of the training environment, passed to
            ``create_env``

    Raises:
        WorkerException -- wraps any ``Exception`` raised while receiving a
            command, executing it, or sending the reply (an unknown command
            raises ``NotImplementedError``, which is wrapped as well)
    """
    # Spawn training environment. If start-up is interrupted (Ctrl+C), ``env``
    # would be unbound and the command loop below would fail with a misleading
    # NameError, so shut this worker down quietly instead.
    try:
        env = create_env(config)
    except KeyboardInterrupt:
        remote.close()
        return

    # Communication interface of the environment process
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "step":
                remote.send(env.step(data))
            elif cmd == "reset":
                remote.send(env.reset())
            elif cmd == "close":
                remote.send(env.close())
                remote.close()
                break
            else:
                raise NotImplementedError("Unknown worker command: {!r}".format(cmd))
        except Exception as e:
            # Chain the cause so the original error stays visible in the report.
            raise WorkerException(e) from e
class Worker:
    """A worker that runs one environment on one process.

    Attributes:
        child {multiprocessing.connection.Connection} -- The main process's end
            of the pipe: send ``(cmd, data)`` tuples here and ``recv`` the
            replies produced by ``worker_process`` in the child process.
        process {multiprocessing.Process} -- The child process running
            ``worker_process``.
    """

    child: multiprocessing.connection.Connection
    process: multiprocessing.Process

    def __init__(self, env_config: dict):
        """Start the child process and connect it through a duplex pipe.

        Arguments:
            env_config {dict} -- Configuration of the training environment
        """
        # Despite its name, ``self.child`` is the end used by *this* process to
        # talk to the child; ``worker_end`` is handed over to the child process.
        self.child, worker_end = multiprocessing.Pipe()
        self.process = multiprocessing.Process(
            target=worker_process, args=(worker_end, env_config)
        )
        self.process.start()
        # The child now holds its own copy of ``worker_end``. Close ours
        # explicitly (instead of relying on garbage collection) so that if the
        # child dies, ``self.child.recv()`` raises EOFError rather than blocking.
        worker_end.close()
# Register tblib's pickle support so traceback objects (and exceptions carrying
# them) can be pickled, e.g. when passed between processes.
# NOTE(review): with no arguments, install() appears to register only the
# exception classes that exist when it is called; WorkerException is defined
# after this line -- confirm whether it needs to be registered explicitly.
import tblib.pickling_support
tblib.pickling_support.install()
import sys
class WorkerException(Exception):
    """Wrap an exception raised inside a worker process.

    The wrapped object is stored in ``ee`` and its traceback in ``tb`` so that
    the original error can be raised again later via :meth:`re_raise`.

    Arguments:
        ee -- The exception (or other object) being wrapped; ``str(ee)`` becomes
            this exception's message.
    """

    def __init__(self, ee):
        self.ee = ee
        # Prefer the traceback attached to the wrapped exception itself; fall
        # back to the exception currently being handled when ``ee`` carries
        # none (e.g. it is not an exception instance).
        tb = ee.__traceback__ if isinstance(ee, BaseException) else None
        if tb is None:
            tb = sys.exc_info()[2]
        self.tb = tb
        super(WorkerException, self).__init__(str(ee))

    def re_raise(self):
        """Raise the wrapped exception again, with its original traceback.

        Raises:
            The wrapped exception ``ee`` with traceback ``tb`` attached; if
            ``ee`` is not an exception instance, this ``WorkerException`` itself.
        """
        if isinstance(self.ee, BaseException):
            # ``raise (exc, None, tb)`` is Python 2 syntax; in Python 3 raising
            # a tuple is a TypeError. ``with_traceback`` is the equivalent.
            raise self.ee.with_traceback(self.tb)
        raise self