Add linearly weighted activations of channel, neuron, neurongroup as objective.

Add a function in util.py that outputs the module names, easing usage with custom networks.
Animadversio committed Jun 1, 2020
1 parent 044317a commit 64ab8f7
Showing 2 changed files with 53 additions and 0 deletions.
37 changes: 37 additions & 0 deletions lucent/optvis/objectives.py
@@ -131,6 +131,43 @@ def inner(model):
        return -model(layer)[:, n_channel].mean()
    return inner

@wrap_objective()
def neuron_weight(layer, weight=None, x=None, y=None, batch=None):
    """Linearly weighted activation across channels at one spatial location as the objective.
    weight: a torch Tensor vector with one entry per channel of the layer.
    """
    @handle_batch(batch)
    def inner(model):
        layer_t = model(layer)
        # Select the activation at position (x, y); defaults to the center of the map.
        layer_t = _extract_act_pos(layer_t, x, y)
        if weight is None:
            return -layer_t.mean()
        else:
            return -(layer_t.squeeze() * weight).mean()
    return inner

@wrap_objective()
def channel_weight(layer, weight, batch=None):
    """Linearly weighted channel activation, averaged over all spatial positions, as the objective.
    weight: a torch Tensor vector with one entry per channel of the layer.
    """
    @handle_batch(batch)
    def inner(model):
        layer_t = model(layer)
        # Broadcast the per-channel weights across batch and spatial dimensions.
        return -(layer_t * weight.view(1, -1, 1, 1)).mean()
    return inner

@wrap_objective()
def localgroup_weight(layer, weight=None, x=None, y=None, wx=1, wy=1, batch=None):
    """Linearly weighted channel activation over a (wy, wx) window anchored at (x, y) as the objective.
    weight: a torch Tensor vector with one entry per channel of the layer.
    """
    @handle_batch(batch)
    def inner(model):
        layer_t = model(layer)
        if weight is None:
            return -(layer_t[:, :, y:y + wy, x:x + wx]).mean()
        else:
            return -(layer_t[:, :, y:y + wy, x:x + wx] * weight.view(1, -1, 1, 1)).mean()
    return inner

def _torch_blur(tensor, out_c=3):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
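
A minimal usage sketch for the three objectives added above; the InceptionV1 model from lucent.modelzoo, the layer name "mixed4a", and the channel count are illustrative assumptions, not part of this commit:

import torch
from lucent.optvis import render, objectives
from lucent.modelzoo import inceptionv1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = inceptionv1(pretrained=True).to(device).eval()

n_chan = 508  # assumed channel count of "mixed4a"; must match the chosen layer
weight = torch.randn(n_chan).to(device)

obj_neuron = objectives.neuron_weight("mixed4a", weight=weight)    # weighted channels at the center location
obj_channel = objectives.channel_weight("mixed4a", weight=weight)  # weighted channels over the whole map
obj_local = objectives.localgroup_weight("mixed4a", weight=weight,
                                         x=5, y=5, wx=3, wy=3)     # weighted channels in a 3x3 window
render.render_vis(model, obj_channel)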
16 changes: 16 additions & 0 deletions lucent/util.py
@@ -19,6 +19,7 @@

import torch
import random
from collections import OrderedDict


def set_seed(seed):
@@ -27,3 +28,18 @@ def set_seed(seed):
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)

def lucent_layernames(net, prefix=[]):
    """Return an OrderedDict mapping each module's name to its string representation."""
    layernames = OrderedDict()
    def hook_layernames(net, prefix=[]):
        """Recursively walk the module tree, joining nested names with '_'."""
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    # e.g. GoogLeNet's aux1 and aux2 layers
                    continue
                layernames["_".join(prefix + [name])] = layer.__repr__()
                hook_layernames(layer, prefix=prefix + [name])
    hook_layernames(net, prefix)
    return layernames
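
A minimal usage sketch; the torchvision VGG16 network stands in as an assumed example of a custom model:

from torchvision import models
from lucent.util import lucent_layernames

net = models.vgg16()
names = lucent_layernames(net)
# The keys are intended to match the layer-name strings the objectives take,
# e.g. ['features', 'features_0', 'features_1', 'features_2', ...]
print(list(names.keys())[:4])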
