forked from ppwwyyxx/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_events.py
64 lines (55 loc) · 2.39 KB
/
test_events.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Facebook, Inc. and its affiliates.
import json
import os
import tempfile
import unittest
from detectron2.utils.events import CommonMetricPrinter, EventStorage, JSONWriter
class TestEventWriter(unittest.TestCase):
    """Tests for the event writers and printers in ``detectron2.utils.events``."""

    @staticmethod
    def _read_json_lines(path):
        """Parse a file containing one JSON object per line into a list of dicts.

        Args:
            path: path of the file to read.

        Returns:
            A list with one decoded object per line, in file order.
        """
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f]

    def testScalar(self):
        """Each periodic JSONWriter.write() records the scalar's latest value."""
        with tempfile.TemporaryDirectory(
            prefix="detectron2_tests"
        ) as tmp_dir, EventStorage() as storage:
            json_file = os.path.join(tmp_dir, "test.json")
            writer = JSONWriter(json_file)
            for k in range(60):
                storage.put_scalar("key", k, smoothing_hint=False)
                # Flush every 20 iterations, i.e. right after k = 19, 39, 59.
                if (k + 1) % 20 == 0:
                    writer.write()
                storage.step()
            writer.close()

            # Read while the temporary directory still exists.
            data = self._read_json_lines(json_file)
            self.assertEqual([int(row["key"]) for row in data], [19, 39, 59])

    def testScalarMismatchedPeriod(self):
        """Scalars put on a different period than the write period are still written.

        "key2" is put at k = 0, 17, 34, 51 and "key" at every k, while writes
        happen after k = 19, 39, 59. The expected output below encodes that each
        write emits one JSON line per distinct iteration, holding the most recent
        value of each key (so "key2" at k = 0 is never emitted).
        """
        with tempfile.TemporaryDirectory(
            prefix="detectron2_tests"
        ) as tmp_dir, EventStorage() as storage:
            json_file = os.path.join(tmp_dir, "test.json")
            writer = JSONWriter(json_file)
            for k in range(60):
                if k % 17 == 0:  # write in a different period
                    storage.put_scalar("key2", k, smoothing_hint=False)
                storage.put_scalar("key", k, smoothing_hint=False)
                if (k + 1) % 20 == 0:
                    writer.write()
                storage.step()
            writer.close()

            data = self._read_json_lines(json_file)
            # A missing key in a row is reported as 0 to keep the lists aligned.
            self.assertEqual(
                [int(row.get("key2", 0)) for row in data], [17, 0, 34, 0, 51, 0]
            )
            self.assertEqual(
                [int(row.get("key", 0)) for row in data], [0, 19, 0, 39, 0, 59]
            )
            self.assertEqual(
                [int(row["iteration"]) for row in data], [17, 19, 34, 39, 51, 59]
            )

    def testPrintETA(self):
        """CommonMetricPrinter logs an ETA only when constructed with an argument.

        The positional argument (10) is presumably the max iteration count —
        TODO confirm against CommonMetricPrinter's signature.
        """
        with EventStorage() as storage:
            printer_with_eta = CommonMetricPrinter(10)
            printer_without_eta = CommonMetricPrinter()

            # Record a "time" scalar over two steps.
            storage.put_scalar("time", 1.0)
            storage.step()
            storage.put_scalar("time", 1.0)
            storage.step()

            with self.assertLogs("detectron2.utils.events") as logs:
                printer_with_eta.write()
            self.assertIn("eta", logs.output[0])

            with self.assertLogs("detectron2.utils.events") as logs:
                printer_without_eta.write()
            self.assertNotIn("eta", logs.output[0])