forked from google-research/google-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay_memory.py
125 lines (103 loc) · 3.87 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thread-safe and checkpoint-able Replay Memory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os.path
import pickle
import random
import re
import threading
import tensorflow.compat.v1 as tf
class ReplayMemory(object):
  """Thread-safe and checkpoint-able replay memory.

  Experiences are held in a bounded `collections.deque`; once `capacity`
  items are stored, appending evicts the oldest entry (FIFO). All mutating
  and sampling operations take an internal lock so the memory can be shared
  between actor and learner threads. `save`/`restore` persist the buffer as
  numbered pickle files via `tf.gfile`, so checkpoints work on any
  filesystem TensorFlow supports (local, GCS, ...).
  """

  def __init__(self, name, capacity):
    """Creates an empty replay memory.

    Args:
      name: str, identifier embedded in checkpoint filenames; must be
        unique per memory within a checkpoint directory.
      capacity: int, maximum number of experiences retained.
    """
    self._name = name
    self._buffer = collections.deque(maxlen=capacity)
    self._lock = threading.Lock()

  def _get_checkpoint_filename(self, number):
    """Returns the checkpoint filename for the given checkpoint number."""
    return 'replay_memory-{}.pkl-{}'.format(self._name, number)

  def _get_latest_checkpoint_number(self, checkpoint_dir_path):
    """Finds the highest existing checkpoint number for this memory.

    Args:
      checkpoint_dir_path: str, directory to scan; may be None/empty.

    Returns:
      int, the largest checkpoint number found, or -1 if the directory
      does not exist or contains no checkpoints for this memory.
    """
    checkpoint_numbers = []
    if checkpoint_dir_path and tf.gfile.Exists(checkpoint_dir_path):
      # re.escape guards against memory names containing regex
      # metacharacters (e.g. '.', '+', '('), which previously could match
      # unrelated files or raise re.error. Compile once, outside the loop.
      pattern = re.compile(
          r'replay_memory-{}\.pkl-(\d+)$'.format(re.escape(self._name)))
      for filename in tf.gfile.ListDirectory(checkpoint_dir_path):
        m = pattern.match(filename)
        if m:
          checkpoint_numbers.append(int(m.group(1)))
    if checkpoint_numbers:
      return max(checkpoint_numbers)
    return -1

  @property
  def size(self):
    """int: current number of stored experiences."""
    return len(self._buffer)

  @property
  def capacity(self):
    """int: maximum number of experiences the memory can hold."""
    return self._buffer.maxlen

  def clear(self):
    """Removes all stored experiences."""
    with self._lock:
      self._buffer.clear()

  def extend(self, experience):
    """Appends an iterable of experiences, evicting oldest if full."""
    with self._lock:
      self._buffer.extend(experience)

  def batch_extend(self, batch_experience, include_init_state=False):
    """Appends a batch of experiences given as parallel sequences.

    Args:
      batch_experience: 7-tuple of equal-length sequences
        (init_states, states, actions, rewards, next_states, dones, infos);
        element i of each sequence forms one experience.
      include_init_state: bool, if True each stored experience keeps the
        init_state as its first element; otherwise it is dropped.
    """
    with self._lock:
      for (init_state, state, action, reward, next_state, done,
           info) in zip(*batch_experience):
        if include_init_state:
          self._buffer.append(
              [init_state, state, action, reward, next_state, done, info])
        else:
          self._buffer.append([state, action, reward, next_state, done, info])

  def get_buffer(self):
    """Returns the internal deque (not a copy, and not lock-protected)."""
    return self._buffer

  def sample_with_replacement(self, size):
    """Returns `size` experiences sampled uniformly with replacement."""
    with self._lock:
      return random.choices(self._buffer, k=size)

  def sample(self, size):
    """Returns `size` distinct experiences; raises ValueError if size > len."""
    with self._lock:
      return random.sample(self._buffer, size)

  def save(self, checkpoint_dir_path, delete_old=False):
    """Pickles the buffer to the next-numbered checkpoint file.

    Args:
      checkpoint_dir_path: str, directory for checkpoint files; created
        if it does not exist.
      delete_old: bool, if True removes the previous checkpoint file
        after the new one is written.
    """
    if not tf.gfile.Exists(checkpoint_dir_path):
      tf.gfile.MakeDirs(checkpoint_dir_path)
    with self._lock:
      latest_checkpoint_number = self._get_latest_checkpoint_number(
          checkpoint_dir_path)
      file_path = os.path.join(
          checkpoint_dir_path,
          self._get_checkpoint_filename(latest_checkpoint_number + 1))
      with tf.gfile.Open(file_path, 'wb') as f:
        pickle.dump(self._buffer, f)
      if delete_old:
        file_path = os.path.join(
            checkpoint_dir_path,
            self._get_checkpoint_filename(latest_checkpoint_number))
        if tf.gfile.Exists(file_path):
          tf.gfile.Remove(file_path)

  def restore(self, checkpoint_dir_path):
    """Replaces the buffer with the latest checkpoint, if one exists.

    A no-op when no checkpoint for this memory is found. Note: unpickling
    trusts the checkpoint file; only restore from trusted directories.

    Args:
      checkpoint_dir_path: str, directory to restore from.
    """
    with self._lock:
      checkpoint_number = self._get_latest_checkpoint_number(
          checkpoint_dir_path)
      if checkpoint_number < 0:
        return
      file_path = os.path.join(checkpoint_dir_path,
                               self._get_checkpoint_filename(checkpoint_number))
      tf.logging.info('Restoring replay memory using checkpoint file: %s',
                      file_path)
      with tf.gfile.Open(file_path, 'rb') as f:
        self._buffer = pickle.load(f)