From f8beda80819c869a1e6884363e60a0782e97976d Mon Sep 17 00:00:00 2001 From: Animadversio Date: Tue, 16 Jun 2020 14:44:25 -0500 Subject: [PATCH] add tests and refactor code --- lucent/modelzoo/util.py | 17 ++++++++++++++--- lucent/util.py | 16 ---------------- tests/modelzoo/test_inceptionv1.py | 6 ++++++ 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/lucent/modelzoo/util.py b/lucent/modelzoo/util.py index d8f1555..d96798a 100644 --- a/lucent/modelzoo/util.py +++ b/lucent/modelzoo/util.py @@ -16,15 +16,26 @@ """Utility functions for modelzoo models.""" from __future__ import absolute_import, division, print_function +from collections import OrderedDict -def get_model_layers(model): - layers = [] +def get_model_layers(model, getLayerRepr=False): + """ + If getLayerRepr is True, return a OrderedDict of layer names, layer representation string pair. + If it's False, just return a list of layer names + """ + layers = OrderedDict() if getLayerRepr else [] # recursive function to get layers def get_layers(net, prefix=[]): if hasattr(net, "_modules"): for name, layer in net._modules.items(): - layers.append("_".join(prefix+[name])) + if layer is None: + # e.g. GoogLeNet's aux1 and aux2 layers + continue + if getLayerRepr: + layers["_".join(prefix+[name])] = layer.__repr__() + else: + layers.append("_".join(prefix + [name])) get_layers(layer, prefix=prefix+[name]) get_layers(model) diff --git a/lucent/util.py b/lucent/util.py index 0906c17..2ce264c 100644 --- a/lucent/util.py +++ b/lucent/util.py @@ -19,7 +19,6 @@ import torch import random -from collections import OrderedDict def set_seed(seed): @@ -28,18 +27,3 @@ def set_seed(seed): torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic=True random.seed(seed) - -def lucent_layernames(net, prefix=[]): - """ Return the layername and str representation of the layer """ - layernames = OrderedDict() - def hook_layernames(net, prefix=[]): - """Recursive function to return the layer name""" - if hasattr(net, "_modules"): - for name, layer in net._modules.items(): - if layer is None: - # e.g. GoogLeNet's aux1 and aux2 layers - continue - layernames["_".join(prefix+[name])] = layer.__repr__() - hook_layernames(layer, prefix=prefix+[name]) - hook_layernames(net, prefix) - return layernames \ No newline at end of file diff --git a/tests/modelzoo/test_inceptionv1.py b/tests/modelzoo/test_inceptionv1.py index 8d1faf4..67f8820 100644 --- a/tests/modelzoo/test_inceptionv1.py +++ b/tests/modelzoo/test_inceptionv1.py @@ -39,3 +39,9 @@ def test_inceptionv1_graph_import(): layer_names = util.get_model_layers(model) for layer_name in important_layer_names: assert layer_name in layer_names + +def test_inceptionv1_import_layer_repr(): + model = inceptionv1() + layer_names = util.get_model_layers(model, getLayerRepr=True) + for layer_name in important_layer_names: + assert layer_names[layer_name] == 'CatLayer()' \ No newline at end of file