forked from minetest/minetest
-
Notifications
You must be signed in to change notification settings - Fork 10
/
data_recorder.py
84 lines (70 loc) · 2.81 KB
/
data_recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import zmq
from minetester.utils import unpack_pb_obs
class DataRecorder:
    """Record messages published on a ZMQ socket to a file, one per line.

    A SUB socket is connected to ``target_address`` and subscribed to every
    topic. While running, each received message is either written to
    ``data_path`` or, in debug mode, decoded and summarised on stdout.

    Args:
        data_path: File the received messages are written to. It is opened
            in write mode when recording starts, so an existing file is
            overwritten (debug mode leaves it untouched).
        target_address: ``host:port`` of the publisher, connected over TCP.
        timeout: Receive timeout in milliseconds.
        max_queue_length: Receive high-water mark of the socket.
        max_attempts: Number of consecutive receive timeouts after which the
            session is considered finished.
        debug: If true, print received actions and do not save to file.
    """

    def __init__(
        self,
        data_path: os.PathLike,
        target_address: str,
        timeout: int = 1000,
        max_queue_length: int = 1200,
        max_attempts: int = 10,
        debug: bool = False,
    ):
        self.target_address = target_address
        self.data_path = data_path
        self.timeout = timeout
        self.max_queue_length = max_queue_length
        self.max_attempts = max_attempts
        # Debug mode prints received actions and does not save to file
        self.debug = debug
        self._recording = False

        # Setup ZMQ
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.SUB)
        # Configure the socket before connecting: libzmq documents that most
        # options only apply to connections made after they are set.
        # Subscribe to all topics
        self.socket.setsockopt(zmq.SUBSCRIBE, b"")
        # Set maximum message queue length (high water mark)
        self.socket.setsockopt(zmq.RCVHWM, self.max_queue_length)
        # Set timeout in milliseconds (honours the ``timeout`` argument
        # instead of a fixed value)
        self.socket.setsockopt(zmq.RCVTIMEO, self.timeout)
        self.socket.connect(f"tcp://{self.target_address}")

    def start(self) -> None:
        """Receive messages until the session ends.

        The loop ends when ``stop()`` is called, when ``max_attempts``
        consecutive receives time out, or when any other ZMQ error occurs.
        Blocks the calling thread for the whole session.
        """
        if self.debug:
            # Nothing is written in debug mode, so the data file is not
            # opened; opening it with "w" would wipe an earlier recording.
            self._receive_loop(None)
            return
        with open(self.data_path, "w") as out:
            self._receive_loop(out)

    def stop(self) -> None:
        """Ask the receive loop to exit after the current receive attempt."""
        self._recording = False

    def _receive_loop(self, out) -> None:
        """Pump messages from the socket into ``out`` (or stdout if None)."""
        self._recording = True
        num_attempts = 0
        while self._recording:
            try:
                # Receive data
                raw_data = self.socket.recv()
            except zmq.ZMQError as err:
                if err.errno != zmq.EAGAIN:
                    print(f"ZMQError: {err}")
                    self._recording = False
                    continue
                # EAGAIN: the receive timed out. Count this attempt before
                # comparing so exactly ``max_attempts`` timeouts end the
                # session.
                num_attempts += 1
                print(f"Reception attempts: {num_attempts}")
                if num_attempts >= self.max_attempts:
                    print("Session finished.")
                    self._recording = False
                continue

            # A message arrived, so the run of timeouts is broken.
            num_attempts = 0
            if out is None:
                self._print_summary(raw_data)
            else:
                # Write data to new line. The textual form of the received
                # object is stored (for bytes that is the ``b'...'`` form);
                # kept as-is so existing readers of the file keep working.
                out.write(str(raw_data) + "\n")

    def _print_summary(self, raw_data) -> None:
        """Decode one message and print its active keys, reward and terminal flag."""
        _, rew, terminal, _, action = unpack_pb_obs(raw_data)
        action_str = "".join(f"{key}, " for key in self._pressed_keys(action))
        print(f"action={action_str}, rew={rew}, T?={terminal}")

    @staticmethod
    def _pressed_keys(action) -> list:
        """Return the keys of ``action`` with a truthy value, except "MOUSE".

        "MOUSE" is excluded before its value is evaluated; presumably it
        holds a non-boolean payload -- TODO confirm against unpack_pb_obs.
        Assumes ``action`` is a mapping.
        """
        return [key for key in action if key != "MOUSE" and action[key]]
if __name__ == "__main__":
    # Example session: listen to a publisher on the local machine and give
    # up after ten consecutive receive timeouts.
    recorder = DataRecorder(
        data_path="data.bin",
        target_address="localhost:5555",
        max_attempts=10,
        debug=True,  # if True, data is not written
    )
    recorder.start()  # warning: file quickly grows very large