from collections import deque
from http.server import HTTPServer, SimpleHTTPRequestHandler
import logging
import queue
from socketserver import ThreadingMixIn
import threading
import time
import traceback
from typing import List

import ray.cloudpickle as pickle
from ray.rllib.env.policy_client import (
    _create_embedded_rollout_worker,
    Commands,
)
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.sampler import SamplerInput
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)

@PublicAPI
class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
    """REST policy server that acts as an offline data source.

    This launches a multi-threaded server that listens on the specified host
    and port to serve policy requests and forward experiences to RLlib. For
    high-performance experience collection, it implements InputReader.

    For an example, run `examples/serving/cartpole_server.py` along
    with `examples/serving/cartpole_client.py --inference-mode=local|remote`.

    WARNING: This class is not meant to be publicly exposed. Anyone who can
    communicate with this server can execute arbitrary code on the machine.
    Use it with caution, in isolated environments, and at your own risk.

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.algorithms.ppo import PPOConfig
        from ray.rllib.env.policy_client import PolicyClient
        from ray.rllib.env.policy_server_input import PolicyServerInput

        addr, port = ...
        config = (
            PPOConfig()
            .environment("CartPole-v1")
            .offline_data(
                input_=lambda ioctx: PolicyServerInput(ioctx, addr, port)
            )
            # Run just 1 server (in the Algorithm's WorkerSet).
            .rollouts(num_rollout_workers=0)
        )
        algo = config.build()
        while True:
            algo.train()

        client = PolicyClient(
            "localhost:9900", inference_mode="local")
        eps_id = client.start_episode()
        env = gym.make("CartPole-v1")
        obs, info = env.reset()
        action = client.get_action(eps_id, obs)
        _, reward, _, _, _ = env.step(action)
        client.log_returns(eps_id, reward)
        algo.stop()
    """
    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext,
        address: str,
        port: int,
        idle_timeout: float = 3.0,
        max_sample_queue_size: int = 20,
    ):
        """Create a PolicyServerInput.

        This class implements rllib.offline.InputReader and can be used with
        any Algorithm by configuring:

            [AlgorithmConfig object]
            .rollouts(num_rollout_workers=0)
            .offline_data(
                input_=lambda ioctx: PolicyServerInput(ioctx, addr, port)
            )

        Note that with num_rollout_workers=0, the algorithm creates only one
        rollout worker / PolicyServerInput. Clients can connect to the
        launched server using rllib.env.PolicyClient. You can increase the
        number of available connections (ports) by setting
        num_rollout_workers to a larger number; the port used by each worker
        is then `port` + the worker's index (see the sketch in the comment
        right below this docstring).

        Args:
            ioctx: IOContext provided by RLlib.
            address: Server addr (e.g., "localhost").
            port: Server port (e.g., 9900).
            idle_timeout: Seconds to wait before placing an empty SampleBatch
                onto the sample queue when no client data has arrived. This
                keeps parallel synchronous sampling from blocking.
            max_sample_queue_size: The maximum size for the sample queue.
                Once full, will purge (throw away) 50% of all samples, oldest
                first, and continue.
        """
        self.rollout_worker = ioctx.worker

        # Protect ourselves from having a bottleneck on the server (learning)
        # side. Once the queue (deque) is full, we throw away 50% (oldest
        # samples first) of the samples, warn, and continue.
        self.samples_queue = deque(maxlen=max_sample_queue_size)
        self.metrics_queue = queue.Queue()
        self.idle_timeout = idle_timeout

        # Forward client-reported metrics directly into the local rollout
        # worker.
        if self.rollout_worker.sampler is not None:
            # This is a bit of a hack, since it patches the `get_metrics`
            # function of the sampler.
            def get_metrics():
                completed = []
                while True:
                    try:
                        completed.append(self.metrics_queue.get_nowait())
                    except queue.Empty:
                        break
                return completed

            self.rollout_worker.sampler.get_metrics = get_metrics
        else:
            # If there is no sampler, act as if there were one to collect
            # metrics from.
            class MetricsDummySampler(SamplerInput):
                """This sampler only maintains a queue to get metrics from."""

                def __init__(self, metrics_queue):
                    """Initializes a MetricsDummySampler instance.

                    Args:
                        metrics_queue: A queue of metrics.
                    """
                    self.metrics_queue = metrics_queue

                def get_data(self) -> SampleBatchType:
                    raise NotImplementedError

                def get_extra_batches(self) -> List[SampleBatchType]:
                    raise NotImplementedError

                def get_metrics(self) -> List[RolloutMetrics]:
                    """Returns metrics computed on a policy client rollout worker."""
                    completed = []
                    while True:
                        try:
                            completed.append(self.metrics_queue.get_nowait())
                        except queue.Empty:
                            break
                    return completed

            self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)

        # Create a request handler that receives commands from the clients
        # and sends data and metrics into the queues.
        handler = _make_handler(
            self.rollout_worker, self.samples_queue, self.metrics_queue
        )
        try:
            # Brief pause before binding the port.
            time.sleep(1)
            HTTPServer.__init__(self, (address, port), handler)
        except OSError:
            print(f"Creating a PolicyServer on {address}:{port} failed!")
            time.sleep(1)
            raise

        logger.info(
            f"Starting connector server at {self.server_name}:{self.server_port}"
        )

        # Start the serving thread, listening on socket and handling commands.
        serving_thread = threading.Thread(name="server", target=self.serve_forever)
        serving_thread.daemon = True
        serving_thread.start()

        # Start a dummy thread that puts empty SampleBatches on the queue,
        # just in case we don't receive anything from clients (or there
        # aren't any). The latter would otherwise block sample collection
        # entirely, even if other workers' PolicyServerInputs receive
        # incoming data from actual clients.
        heart_beat_thread = threading.Thread(
            name="heart-beat", target=self._put_empty_sample_batch_every_n_sec
        )
        heart_beat_thread.daemon = True
        heart_beat_thread.start()

    @override(InputReader)
    def next(self):
        # Blocking wait until there is something in the deque.
        while len(self.samples_queue) == 0:
            time.sleep(0.1)
        # Utilize last items first in order to remain as close as possible
        # to operating on-policy.
        return self.samples_queue.pop()

    def _put_empty_sample_batch_every_n_sec(self):
        # Places an empty SampleBatch every `idle_timeout` seconds onto the
        # `samples_queue`. This avoids hanging of all RolloutWorkers parallel
        # to this one in case this PolicyServerInput does not have incoming
        # data (e.g. no client connected) and the driver algorithm uses
        # parallel synchronous sampling (e.g. PPO).
        while True:
            time.sleep(self.idle_timeout)
            self.samples_queue.append(SampleBatch())
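

# A minimal sketch of the wire protocol that the request handler below
# implements: a client POSTs a pickled dict with a "command" key and gets a
# pickled response dict back. This assumes a server already listening on
# localhost:9900; `requests` is used only for illustration and is not a
# dependency of this module.
#
#   import requests
#   import ray.cloudpickle as pickle
#   from ray.rllib.env.policy_client import Commands
#
#   body = pickle.dumps({"command": Commands.GET_WEIGHTS})
#   resp = requests.post("http://localhost:9900", data=body)
#   weights = pickle.loads(resp.content)["weights"]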
def _make_handler(rollout_worker, samples_queue, metrics_queue):
    # Only used in remote inference mode. We must create a new rollout worker
    # then, since the original worker doesn't have the env properly wrapped
    # in an ExternalEnv interface.
    child_rollout_worker = None
    inference_thread = None
    lock = threading.Lock()

    def setup_child_rollout_worker():
        nonlocal lock

        with lock:
            nonlocal child_rollout_worker
            nonlocal inference_thread

            if child_rollout_worker is None:
                (
                    child_rollout_worker,
                    inference_thread,
                ) = _create_embedded_rollout_worker(
                    rollout_worker.creation_args(), report_data
                )
                child_rollout_worker.set_weights(rollout_worker.get_weights())

    def report_data(data):
        nonlocal child_rollout_worker

        batch = data["samples"]
        batch.decompress_if_needed()
        samples_queue.append(batch)
        # Deque is full -> purge 50% (oldest samples first).
        if len(samples_queue) == samples_queue.maxlen:
            logger.warning(
                "PolicyServerInput queue is full! Purging half of the "
                "samples (oldest first)."
            )
            for _ in range(samples_queue.maxlen // 2):
                samples_queue.popleft()

        for rollout_metric in data["metrics"]:
            metrics_queue.put(rollout_metric)

        if child_rollout_worker is not None:
            child_rollout_worker.set_weights(
                rollout_worker.get_weights(), rollout_worker.get_global_vars()
            )

    class Handler(SimpleHTTPRequestHandler):
        def do_POST(self):
            # Default to 0 if the Content-Length header is missing.
            content_len = int(self.headers.get("Content-Length", 0))
            raw_body = self.rfile.read(content_len)
            parsed_input = pickle.loads(raw_body)
            try:
                response = self.execute_command(parsed_input)
                self.send_response(200)
                self.end_headers()
                self.wfile.write(pickle.dumps(response))
            except Exception:
                self.send_error(500, traceback.format_exc())

        def execute_command(self, args):
            command = args["command"]
            response = {}

            # Local inference commands:
            if command == Commands.GET_WORKER_ARGS:
                logger.info("Sending worker creation args to client.")
                response["worker_args"] = rollout_worker.creation_args()
            elif command == Commands.GET_WEIGHTS:
                logger.info("Sending worker weights to client.")
                response["weights"] = rollout_worker.get_weights()
                response["global_vars"] = rollout_worker.get_global_vars()
            elif command == Commands.REPORT_SAMPLES:
                logger.info(
                    "Got sample batch of size {} from client.".format(
                        args["samples"].count
                    )
                )
                report_data(args)
            # Remote inference commands:
            elif command == Commands.START_EPISODE:
                setup_child_rollout_worker()
                assert inference_thread.is_alive()
                response["episode_id"] = child_rollout_worker.env.start_episode(
                    args["episode_id"], args["training_enabled"]
                )
            elif command == Commands.GET_ACTION:
                assert inference_thread.is_alive()
                response["action"] = child_rollout_worker.env.get_action(
                    args["episode_id"], args["observation"]
                )
            elif command == Commands.LOG_ACTION:
                assert inference_thread.is_alive()
                child_rollout_worker.env.log_action(
                    args["episode_id"], args["observation"], args["action"]
                )
            elif command == Commands.LOG_RETURNS:
                assert inference_thread.is_alive()
                if args["done"]:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"],
                        args["reward"],
                        args["info"],
                        args["done"],
                    )
                else:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"], args["reward"], args["info"]
                    )
            elif command == Commands.END_EPISODE:
                assert inference_thread.is_alive()
                child_rollout_worker.env.end_episode(
                    args["episode_id"], args["observation"]
                )
            else:
                raise ValueError("Unknown command: {}".format(command))
            return response

    return Handler
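

# A minimal launch sketch mirroring the class docstring example above;
# "localhost"/9900 are assumed address/port values, and training runs until
# interrupted.
if __name__ == "__main__":
    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        # Serve policy requests on localhost:9900 (assumed values).
        .offline_data(
            input_=lambda ioctx: PolicyServerInput(ioctx, "localhost", 9900)
        )
        # Run just 1 server (in the Algorithm's WorkerSet).
        .rollouts(num_rollout_workers=0)
    )
    algo = config.build()
    while True:
        algo.train()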