forked from CSAILVision/gandissect
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path modelconfig.py
125 lines (111 loc) · 4.81 KB
/
modelconfig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import numbers
import torch
from netdissect.autoeval import autoimport_eval
from netdissect.progress import print_progress
from netdissect.nethook import InstrumentedModel
from netdissect.easydict import EasyDict
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model, or an
        already-constructed torch.nn.Module.
      pthfile: (optional) filename of .pth file for the model.
      layer: (optional) a single layer name; when given it replaces `layers`.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing.
        NOTE(review): `edit` is not read anywhere in this function; confirm
        whether InstrumentedModel supports editing unconditionally.
      gen: True for a generator model. One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: True to use CUDA.
    Keyword arguments override same-named attributes of `args`.

    Returns None (after printing a message) when no model is specified.

    The constructed model will be decorated with the following attributes:
      meta: numeric top-level entries of a pthfile that wraps its weights
        under a 'state_dict' key (empty otherwise).
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.
    (retained / ablation / replacement are not set in this function; they
    are expected to come from InstrumentedModel.)
    When editing, the feature value x will be replaced by:
      `x = (replacement * ablation) + (x * (1 - ablation))`
    '''
    args = EasyDict(vars(args), **kwargs)
    # Construct the network.
    if getattr(args, 'model', None) is None:
        print_progress('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        # NOTE(review): autoimport_eval presumably evaluates the string as
        # Python code; only pass trusted values here.
        model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model (its sole child is .module).
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    # Load its state dict, collecting any numeric metadata stored with it.
    meta = {}
    pthfile = getattr(args, 'pthfile', None)
    if pthfile is not None:
        meta = _load_checkpoint(model, pthfile)
    # Decide which layers to instrument; a single `layer` takes precedence.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    if getattr(args, 'layers', None) is None:
        args.layers = _default_layers(model)
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))
    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    model.meta = meta
    # Instrument the layers.
    model.retain_layers(args.layers)
    model.eval()
    if getattr(args, 'cuda', False):
        model.cuda()
    # Annotate input, output, and feature shapes.
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model


def _load_checkpoint(model, pthfile):
    '''
    Load the weights in `pthfile` into `model` and return numeric metadata.

    The file may hold a bare state dict, or a dict with a 'state_dict'
    entry; in the latter case every other top-level entry whose value is a
    number is returned in the metadata dict. Otherwise returns {}.
    '''
    # map_location='cpu' lets checkpoints saved from a GPU load on machines
    # without CUDA; load_state_dict copies values into the model's own
    # parameters, which are moved to the requested device later.
    # NOTE: torch.load unpickles its input; only load trusted files.
    data = torch.load(pthfile, map_location='cpu')
    meta = {}
    if 'state_dict' in data:
        meta = {key: value for key, value in data.items()
                if isinstance(value, numbers.Number)}
        data = data['state_dict']
    model.load_state_dict(data)
    return meta


def _default_layers(model):
    '''
    Choose default layers: nontrivial top-level children except the last.

    Wrapper modules that have exactly one named child are descended
    through first, and their names are joined into a dotted prefix.
    Activation and pooling modules are skipped.
    '''
    container = model
    prefix = ''
    while len(list(container.named_children())) == 1:
        name, container = next(container.named_children())
        prefix += name + '.'
    skipped_module_paths = (
        # Skip ReLU and other activations.
        'torch.nn.modules.activation',
        # Skip pooling layers.
        'torch.nn.modules.pooling',
    )
    names = [prefix + name
             for name, module in container.named_children()
             if type(module).__module__ not in skipped_module_paths]
    # Exclude the final remaining layer.
    return names[:-1]
def annotate_model_shapes(model, gen=False, imgsize=None):
    '''
    Run `model` once on an all-zeros input and record tensor shapes on it.

    Args:
      model: a torch.nn.Module that also provides `retained_features()`,
        a mapping of layer names to the feature tensors captured during
        evaluation (as provided by InstrumentedModel).
      gen: True for a generator model. Its input shape is inferred from
        the first Conv2d / ConvTranspose2d / Linear module in
        `model.modules()` order: (1, in_channels, 1, 1) for convolutions,
        (1, in_features) for Linear.
      imgsize: for non-generator models, the (y, x) size of the RGB
        input; a single int is treated as a square (imgsize, imgsize).

    Sets `model.input_shape`, `model.feature_shape` (layer name -> shape)
    and `model.output_shape`, then returns `model`.

    Raises:
      ValueError: if neither `gen` nor `imgsize` is given, or if `gen` is
        set but the model contains no Conv2d/ConvTranspose2d/Linear module.
    '''
    # Explicit check instead of assert so it survives `python -O`.
    if imgsize is None and not gen:
        raise ValueError('imgsize is required for non-generator models')
    # Figure the input shape.
    if gen:
        # We can guess a generator's input shape by looking at the model:
        # the first conv/linear module determines the input feature size.
        first_layer = next(
            (module for module in model.modules()
             if isinstance(module, (torch.nn.Conv2d,
                                    torch.nn.ConvTranspose2d,
                                    torch.nn.Linear))),
            None)
        if first_layer is None:
            raise ValueError('cannot infer generator input shape: no Conv2d, '
                             'ConvTranspose2d or Linear layer found')
        # 2d input if the first layer is linear, 4d if convolutional.
        if isinstance(first_layer, torch.nn.Linear):
            input_shape = (1, first_layer.in_features)
        else:
            input_shape = (1, first_layer.in_channels, 1, 1)
    else:
        # For a classifier, the input image shape is given as an argument.
        if isinstance(imgsize, int):
            imgsize = (imgsize, imgsize)
        input_shape = (1, 3) + tuple(imgsize)
    # Run the model once to observe feature shapes, on the device holding
    # its parameters (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = (first_param.device if first_param is not None
              else torch.device('cpu'))
    dry_run = torch.zeros(input_shape, device=device)
    with torch.no_grad():
        output = model(dry_run)
    # Annotate shapes.
    model.input_shape = input_shape
    model.feature_shape = {layer: feature.shape
                           for layer, feature
                           in model.retained_features().items()}
    model.output_shape = output.shape
    return model