forked from facebookresearch/PyTorch-BigGraph
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstats.py
68 lines (52 loc) · 2.25 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.
from collections import defaultdict
from statistics import mean
from typing import Iterable, Type
from torchbiggraph.types import FloatTensorType
def average_of_sums(*tensors: FloatTensorType) -> float:
    """Return the arithmetic mean of the per-tensor element sums.

    Each argument is reduced to a single Python number with
    ``.sum().item()``, and those numbers are then averaged with
    :func:`statistics.mean`.

    Raises:
        statistics.StatisticsError: if called with no tensors at all.
    """
    per_tensor_totals = [tensor.sum().item() for tensor in tensors]
    return mean(per_tensor_totals)
class Stats:
    """A set of named numeric metrics together with a sample count.

    Instances are built as ``Stats(count=n, name1=value1, ...)``: ``count`` is
    a required keyword argument, and every other keyword argument is stored
    as a metric in the ``metrics`` dict. Subclasses may be used to tag the
    stats produced by a particular operation (say, training or evaluation).
    The aggregation helpers construct their results through ``cls`` /
    ``type(self)``, so they return an instance of the same subclass.

    Helpers provided:
      * ``Stats.sum(iterable)`` adds up counts and metrics across instances;
      * ``average()`` divides every metric by ``count``;
      * ``str()`` renders a one-line human-readable summary;
      * ``==`` compares both ``count`` and ``metrics``.
    """

    def __init__(self, *, count: int, **metrics: float) -> None:
        # Number of items the metrics were accumulated over.
        self.count = count
        # Metric name -> value; a fresh dict built from the keyword arguments.
        self.metrics = metrics

    @classmethod
    def sum(cls: Type["Stats"], stats: Iterable["Stats"]) -> "Stats":
        """Return a stats whose count and metrics are the sums of the given stats.

        ``stats`` is traversed exactly once, so a generator or any other
        one-shot iterator is accepted. A metric present in only some of the
        inputs is summed over the inputs that have it. An empty iterable
        yields ``count=0`` and no metrics.
        """
        total_count = 0
        # Missing metric names start at 0 so the first value can just be added.
        total_metrics = defaultdict(int)
        # Single pass: counting in a second loop over `stats` would see an
        # already-exhausted iterator and silently produce count=0.
        for s in stats:
            total_count += s.count
            for k, v in s.metrics.items():
                total_metrics[k] += v
        return cls(count=total_count, **total_metrics)

    def average(self) -> "Stats":
        """Return these stats with every metric divided by ``count``.

        ``count`` itself is carried over unchanged. When ``count`` is 0 there
        is nothing to divide by, so ``self`` (not a copy) is returned as-is.
        """
        if self.count == 0:
            return self
        return type(self)(
            count=self.count,
            **{k: v / self.count for k, v in self.metrics.items()},
        )

    def __str__(self) -> str:
        return "%s , count: %d" % (
            " , ".join("%s: %.6g" % (k, v) for k, v in self.metrics.items()),
            self.count,
        )

    def __repr__(self) -> str:
        # Constructor-style representation, e.g. "Stats(count=2, loss=1.5)".
        fields = ["count=%r" % (self.count,)]
        fields.extend("%s=%r" % (k, v) for k, v in self.metrics.items())
        return "%s(%s)" % (type(self).__name__, ", ".join(fields))

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types; Python then falls
        # back to identity, so the result is still False.
        if not isinstance(other, Stats):
            return NotImplemented
        return self.count == other.count and self.metrics == other.metrics