-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcompare.py
78 lines (66 loc) · 2.24 KB
/
compare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch import nn, Tensor
from torch_flops import TorchFLOPsByFX
import torchanalyse
from thop import profile
from ptflops import get_model_complexity_info
import argparse
# Command-line switches that select which SimpleModel variant is profiled.
parser = argparse.ArgumentParser(
    description=(
        "Compare the FLOPs/MACs reported by torch_flops, torchanalyse, "
        "thop and ptflops on a single-layer model."
    ),
)
parser.add_argument(
    "--linear_no_bias",
    action='store_true',
    help="build the nn.Linear layer without a bias term",
)
parser.add_argument(
    "--add_one",
    action='store_true',
    help="add 1 to the linear layer's output inside forward()",
)
# Parsed at module level; the result is consumed when the model is built.
inp_args = parser.parse_args()
class SimpleModel(nn.Module):
    """One ``nn.Linear(5, 4)`` layer, optionally followed by an in-place ``+ 1``.

    Args:
        args: Object exposing two boolean attributes:
            ``linear_no_bias`` -- when true, the linear layer is created
            with ``bias=False``;
            ``add_one`` -- when true, ``forward`` adds 1 to the layer output.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.layer = nn.Linear(5, 4, bias=not args.linear_no_bias)
        # Double leading underscore: stored name-mangled as
        # ``_SimpleModel__add_one``.
        self.__add_one = args.add_one

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear layer and, if enabled, add 1 to its output.

        Args:
            x: Input tensor passed to the linear layer (5 input features).

        Returns:
            The linear layer's output, incremented by 1 when the ``add_one``
            option was set at construction.
        """
        x = self.layer(x)
        if self.__add_one:
            # Kept as an in-place add on the layer output (not ``x = x + 1``)
            # so the operation seen by the profilers is unchanged.
            x += 1.
        return x
if __name__ == "__main__":
    # Build the model from the parsed CLI flags, run one forward pass, then
    # print the cost of that same model/input as reported by four profilers:
    # torch_flops, torchanalyse, thop and ptflops (in that order).
    stars = "*" * 40
    rule = "=" * 80

    def _section(title):
        """Print the starred banner that opens one report section."""
        print(stars + f" {title} " + stars)

    model = SimpleModel(inp_args)
    x = torch.randn(1, 5)
    y = model(x)

    _section("Model")
    print(model)
    print(y)
    print(rule)

    # ========= torch_flops
    _section("torch_flops")
    flops_counter = TorchFLOPsByFX(model)
    # flops_counter.graph_model.graph.print_tabular()
    flops_counter.propagate(x)
    flops_counter.print_result_table()
    # show=False: fetch the total here and print it in this script's format.
    flops_1 = flops_counter.print_total_flops(show=False)
    print(f"torch_flops: {flops_1} FLOPs")
    print(rule)

    # ========= torchanalyse
    _section("torchanalyse")
    unit = torchanalyse.Unit(unit_flop='mFLOP')
    system = torchanalyse.System(
        unit,
        frequency=940,
        flops=123,
        onchip_mem_bw=900,
        pe_min_density_support=0.0001,
        accelerator_type="structured",
        model_on_chip_mem_implications=False,
        on_chip_mem_size=32,
    )
    result_2 = torchanalyse.profiler(model, x, system, unit)
    # Sum the per-row 'Flops (mFLOP)' column and rescale before printing.
    # NOTE(review): the /1e3 factor is carried over unchanged; confirm it
    # matches how torchanalyse defines the 'mFLOP' unit.
    flops_2 = sum(result_2['Flops (mFLOP)'].values) / 1e3
    print(f"torchanalyse: {flops_2:.0f} FLOPs")
    print(rule)

    # ========= thop
    _section("thop")
    # The second return value (parameter count) is not used by this script.
    macs_1, _ = profile(model, inputs=(x, ))
    print(f"thop: {macs_1:.0f} MACs")
    print(rule)

    # ========= ptflops
    _section("ptflops")
    # ptflops is given the input's shape tuple rather than the tensor itself.
    macs_2, _ = get_model_complexity_info(
        model,
        tuple(x.shape),
        as_strings=False,
        print_per_layer_stat=True,
        verbose=True,
    )
    print(f"ptflops: {macs_2:.0f} MACs")
    print(rule)