utils.py
import torch as ch
import numpy as np


def split_generator(generator):
    """Unzip an iterable of tuples into one list per tuple position."""
    return list(map(list, zip(*generator)))
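
# A minimal sketch of split_generator's behavior (hypothetical values):
#   split_generator([('a', 1), ('b', 2)])  ->  [['a', 'b'], [1, 2]]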


def flatten_model(model, whitelist_keywords=None):
    """
    Flatten the modules of a model into a list of leaf modules.
    whitelist_keywords: modules whose names contain any of these keywords
    are kept whole and not split into their children.
    """
    names, modules = split_generator(model.named_children())
    # zip iterates over the live lists, so children appended below are
    # visited on later iterations, flattening the whole module tree.
    for i, (n, m) in enumerate(zip(names, modules)):
        if whitelist_keywords is not None:
            if any(keyword in n for keyword in whitelist_keywords):
                continue  # whitelisted: do not split
        if len(list(m.children())) > 0:
            # expand this module: append its children under parent-qualified
            # names, then mark the parent itself for removal
            new_names, new_modules = split_generator(m.named_children())
            modules += new_modules
            modules[i] = None
            names += [f'{names[i]}::{each}' for each in new_names]
            names[i] = None
    return ([each for each in names if each is not None],
            [each for each in modules if each is not None])
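
# Sketch of the flattening on a toy model (hypothetical, not from this file):
#   net = ch.nn.Sequential(ch.nn.Sequential(ch.nn.Conv2d(3, 8, 3), ch.nn.ReLU()))
#   flatten_model(net)  ->  (['0::0', '0::1'], [Conv2d(3, 8, ...), ReLU()])
# Only leaf modules remain; '::' joins parent and child names.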


def add_hooks_preact_resnet18(model, config, verbose=False):
    """
    Register forward hooks on a PreAct ResNet-18 that record layer outputs.
    Returns the model, the sorted names of the hooked layers, and a dict
    that the hooks fill with detached outputs on every forward pass.
    """
    names, modules = flatten_model(model)
    assert len(names) == len(modules)
    # Hook the normalization layers only: BN outputs are always passed
    # through ReLUs (even for skip connections).
    norm_module = ch.nn.BatchNorm2d if config.use_bn else ch.nn.Identity
    layer_ids = np.asarray([i for i, each in enumerate(modules)
                            if type(each) == norm_module])
    activation = {}

    def get_activation(name):
        # Closure: each hook stores its layer's detached output in
        # `activation` under the layer's name.
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    for each in layer_ids:
        modules[each].register_forward_hook(get_activation(names[each]))
    layer_names = np.sort(np.asarray(names)[layer_ids])
    if verbose:
        print('Adding hooks to', layer_names)
    return model, layer_names, activation
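
# Usage sketch (hypothetical config and input size; assumes `config` exposes
# a boolean `use_bn` as read above, and a CIFAR-sized PreAct ResNet-18):
#   from types import SimpleNamespace
#   config = SimpleNamespace(use_bn=True)
#   model, layer_names, activation = add_hooks_preact_resnet18(model, config)
#   _ = model(ch.randn(1, 3, 32, 32))   # forward pass populates `activation`
#   feats = activation[layer_names[0]]  # detached output of one hooked layer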