Skip to content

Commit

Permalink
add tests and refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
Animadversio committed Jun 16, 2020
1 parent b5d3768 commit f8beda8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
17 changes: 14 additions & 3 deletions lucent/modelzoo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,26 @@
"""Utility functions for modelzoo models."""

from __future__ import absolute_import, division, print_function
from collections import OrderedDict


def get_model_layers(model):
layers = []
def get_model_layers(model, getLayerRepr=False):
"""
If getLayerRepr is True, return a OrderedDict of layer names, layer representation string pair.
If it's False, just return a list of layer names
"""
layers = OrderedDict() if getLayerRepr else []
# recursive function to get layers
def get_layers(net, prefix=[]):
if hasattr(net, "_modules"):
for name, layer in net._modules.items():
layers.append("_".join(prefix+[name]))
if layer is None:
# e.g. GoogLeNet's aux1 and aux2 layers
continue
if getLayerRepr:
layers["_".join(prefix+[name])] = layer.__repr__()
else:
layers.append("_".join(prefix + [name]))
get_layers(layer, prefix=prefix+[name])

get_layers(model)
Expand Down
16 changes: 0 additions & 16 deletions lucent/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch
import random
from collections import OrderedDict


def set_seed(seed):
Expand All @@ -28,18 +27,3 @@ def set_seed(seed):
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
random.seed(seed)

def lucent_layernames(net, prefix=[]):
""" Return the layername and str representation of the layer """
layernames = OrderedDict()
def hook_layernames(net, prefix=[]):
"""Recursive function to return the layer name"""
if hasattr(net, "_modules"):
for name, layer in net._modules.items():
if layer is None:
# e.g. GoogLeNet's aux1 and aux2 layers
continue
layernames["_".join(prefix+[name])] = layer.__repr__()
hook_layernames(layer, prefix=prefix+[name])
hook_layernames(net, prefix)
return layernames
6 changes: 6 additions & 0 deletions tests/modelzoo/test_inceptionv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ def test_inceptionv1_graph_import():
layer_names = util.get_model_layers(model)
for layer_name in important_layer_names:
assert layer_name in layer_names

def test_inceptionv1_import_layer_repr():
model = inceptionv1()
layer_names = util.get_model_layers(model, getLayerRepr=True)
for layer_name in important_layer_names:
assert layer_names[layer_name] == 'CatLayer()'

0 comments on commit f8beda8

Please sign in to comment.