-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathshow.py
31 lines (25 loc) · 949 Bytes
/
show.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/env python -u
# -*- coding: utf-8 -*-
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def show_params(nnet):
    """Print each parameter's name and size, followed by the total count.

    Output is framed by a "Model Parameters" banner line and a closing
    rule of 98 ``=`` characters (the same width as the banner line).

    Args:
        nnet: Object exposing ``named_modules()``, which yields
            ``(name, module)`` pairs. Each inspected module must expose
            ``named_parameters()`` yielding ``(name, parameter)`` pairs,
            and each parameter must expose ``size()`` returning an
            iterable of dimension lengths.

    Returns:
        None. All results are written to standard output.
    """
    print("=" * 40, "Model Parameters", "=" * 40)
    num_params = 0
    for module_name, m in nnet.named_modules():
        # Only the entry registered under the empty name is inspected.
        # NOTE(review): presumably this is the root module (PyTorch yields
        # the root under '' from named_modules()) -- confirm.
        if module_name != '':
            continue
        for name, params in m.named_parameters():
            shape = params.size()
            print(name, shape)
            # Element count is the product of all dimension lengths; an
            # empty shape therefore counts as a single element.
            count = 1
            for dim in shape:
                count = count * dim
            num_params += count
    print('[*] Parameter Size: {}'.format(num_params))
    print("=" * 98)
def show_model(nnet):
    """Print the model's structure between two banner lines.

    Output is framed by a "Model Structures" banner line and a closing
    rule of 98 ``=`` characters (the same width as the banner line).

    Args:
        nnet: Object exposing ``named_modules()``, which yields
            ``(name, module)`` pairs. Every module whose name is the
            empty string is printed via ``print``.

    Returns:
        None. All results are written to standard output.
    """
    print("=" * 40, "Model Structures", "=" * 40)
    for module_name, m in nnet.named_modules():
        # Only the entry registered under the empty name is printed.
        # NOTE(review): presumably this is the root module, whose printed
        # form already lists its children -- confirm.
        if module_name != '':
            continue
        print(m)
    print("=" * 98)