learner_thread.py (forked from ray-project/ray)

import copy
import queue
import threading
from typing import Dict, Optional

from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.util.iter import _NextValueNotReady
from ray.util.timer import _Timer
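
# `tf1` and `tf` come back as None if TensorFlow is not installed; this
# module only touches them when the configured framework is a TF variant.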
tf1, tf, tfv = try_import_tf()


@OldAPIStack
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    The learner thread communicates with the main thread through queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient computations (TF session runs)
    off the main thread improves overall throughput.
    """

    def __init__(
        self,
        local_worker: RolloutWorker,
        minibatch_buffer_size: int,
        num_sgd_iter: int,
        learner_queue_size: int,
        learner_queue_timeout: int,
    ):
        """Initialize the learner thread.

        Args:
            local_worker: Process-local rollout worker holding the
                policies this thread will call `learn_on_batch()` on.
            minibatch_buffer_size: Max number of train batches to store
                in the minibatching buffer.
            num_sgd_iter: Number of passes to learn on per train batch.
            learner_queue_size: Max size of the queue of inbound
                train batches to this thread.
            learner_queue_timeout: Raise an exception if the queue has
                been empty for this many seconds.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter,
        )
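        # Timers tracking where this thread spends its time: dequeueing
        # batches, running the gradient update, and (in multi-GPU subclasses)
        # loading batches onto devices and waiting for those loads.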
        self.queue_timer = _Timer()
        self.grad_timer = _Timer()
        self.load_timer = _Timer()
        self.load_wait_timer = _Timer()
        # Run as a daemon so this thread dies together with the main process.
        self.daemon = True
        self.policy_ids_updated = []
        self.learner_info = {}
        self.stopped = False
        self.num_steps = 0

    def run(self) -> None:
        # Switch on eager mode if configured.
        if self.local_worker.config.framework_str == "tf2":
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self) -> Optional[_NextValueNotReady]:
        with self.queue_timer:
            try:
                batch, _ = self.minibatch_buffer.get()
            except queue.Empty:
                return _NextValueNotReady()

        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=1)
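            # The worker's policies may be swapped in and out of its policy
            # map; holding the lock around `learn_on_batch()` keeps a policy's
            # state from being swapped out mid-update.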
            if self.local_worker.config.policy_states_are_swappable:
                self.local_worker.lock()
            multi_agent_results = self.local_worker.learn_on_batch(batch)
            if self.local_worker.config.policy_states_are_swappable:
                self.local_worker.unlock()

            self.policy_ids_updated.extend(list(multi_agent_results.keys()))
            for pid, results in multi_agent_results.items():
                learner_info_builder.add_learn_on_batch_results(results, pid)
            self.learner_info = learner_info_builder.finalize()

        self.num_steps += 1
        # Put tuple: env-steps, agent-steps, and learner info into the queue.
        self.outqueue.put((batch.count, batch.agent_steps(), self.learner_info))
        self.learner_queue_size.push(self.inqueue.qsize())

    def add_learner_metrics(self, result: Dict, overwrite_learner_info=True) -> Dict:
        """Add internal metrics to a result dict."""

        def timer_to_ms(timer):
            return round(1000 * timer.mean, 3)

        # Both branches report the same queue stats and timing breakdown; the
        # flag only controls whether the learner info dict is (re)written.
        info = {
            "learner_queue": self.learner_queue_size.stats(),
            "timing_breakdown": {
                "learner_grad_time_ms": timer_to_ms(self.grad_timer),
                "learner_load_time_ms": timer_to_ms(self.load_timer),
                "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
                "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
            },
        }
        if overwrite_learner_info:
            info[LEARNER_INFO] = copy.deepcopy(self.learner_info)
        result["info"].update(info)
        return result
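
# Usage sketch (illustrative only): `worker` is assumed to be an
# already-built RolloutWorker and `train_batch` a SampleBatch it produced;
# both are hypothetical placeholders, not defined in this module.
#
#     learner = LearnerThread(
#         local_worker=worker,
#         minibatch_buffer_size=1,
#         num_sgd_iter=1,
#         learner_queue_size=16,
#         learner_queue_timeout=300,
#     )
#     learner.start()
#     # Feed an inbound train batch; `run()` pulls it through the
#     # MinibatchBuffer and calls `learn_on_batch()` on the local worker.
#     learner.inqueue.put(train_batch)
#     # Each processed batch yields (env_steps, agent_steps, learner_info).
#     env_steps, agent_steps, learner_info = learner.outqueue.get()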