forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathminibatch_buffer.py
61 lines (53 loc) · 1.9 KB
/
minibatch_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Any, Tuple
import queue
from ray.rllib.utils.annotations import OldAPIStack
@OldAPIStack
class MinibatchBuffer:
    """Ring buffer of recent data batches for minibatch SGD.

    Each item pulled from ``inqueue`` occupies one of ``size`` slots and is
    handed out by ``get()`` a limited number of times (its remaining "ttl")
    before the slot is cleared and later refilled from the queue. Slots are
    visited in round-robin order, so consecutive ``get()`` calls cycle through
    the buffered items.

    This is for use with AsyncSamplesOptimizer.
    """

    def __init__(
        self,
        inqueue: queue.Queue,
        size: int,
        timeout: float,
        num_passes: int,
        init_num_passes: int = 1,
    ):
        """Initialize a minibatch buffer.

        Args:
            inqueue: Queue to populate the internal ring buffer from.
            size: Max number of data items to buffer. Must be >= 1.
            timeout: Timeout passed to ``inqueue.get()`` whenever a slot
                needs to be refilled.
            num_passes: Max number of times each data item should be emitted.
            init_num_passes: Initial number of passes for each data item.
                If smaller than ``num_passes``, the pass count granted to
                each newly fetched item increases by one per fetch until it
                reaches ``num_passes``.

        Raises:
            ValueError: If ``size`` is less than 1.
        """
        # An empty ring has no slot to index into; fail here with a clear
        # message instead of an IndexError on the first get() call.
        if size < 1:
            raise ValueError(f"MinibatchBuffer size must be >= 1, got {size}.")
        self.inqueue = inqueue
        self.size = size
        self.timeout = timeout
        # Upper bound for the pass count granted to a newly fetched item.
        self.max_initial_ttl = num_passes
        # Pass count granted to the next fetched item; ramps up towards
        # max_initial_ttl inside get().
        self.cur_initial_ttl = init_num_passes
        # Slot contents; None marks an empty / released slot.
        self.buffers = [None] * size
        # Remaining passes per slot; <= 0 means the slot must be refilled.
        self.ttl = [0] * size
        # Slot that the next get() call reads from.
        self.idx = 0

    def get(self) -> Tuple[Any, bool]:
        """Get a new batch from the internal ring buffer.

        If the current slot has no passes left, it is first refilled with a
        new item from ``inqueue`` (``inqueue.get(timeout=self.timeout)``).
        The item in the current slot is then returned, its remaining pass
        count is decremented, and the buffer advances to the next slot.

        Returns:
            buf: Data item saved from inqueue.
            released: True if the item is now removed from the ring buffer
                (i.e. this was its last pass).

        Raises:
            queue.Empty: Propagated from ``inqueue.get()`` if the slot needs
                refilling and no item arrives within ``timeout``. The buffer
                state is left unchanged in that case.
        """
        if self.ttl[self.idx] <= 0:
            self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
            self.ttl[self.idx] = self.cur_initial_ttl
            # Grant one more pass to each subsequently fetched item until
            # the configured maximum is reached.
            if self.cur_initial_ttl < self.max_initial_ttl:
                self.cur_initial_ttl += 1
        buf = self.buffers[self.idx]
        self.ttl[self.idx] -= 1
        released = self.ttl[self.idx] <= 0
        if released:
            # Drop the buffer's reference so the released item can be freed.
            self.buffers[self.idx] = None
        self.idx = (self.idx + 1) % len(self.buffers)
        return buf, released