Skip to content

Commit

Permalink
Add support for D3 in vgg16
Browse files Browse the repository at this point in the history
  • Loading branch information
AdMoR committed Nov 1, 2022
1 parent 02029cd commit 87706fb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
11 changes: 8 additions & 3 deletions neural_styles/nn_utils/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class VGG16Layers(Enum):
Conv5_1 = 24
Conv5_2 = 26
Conv5_3 = 28
D3 = -1

def __repr__(self):
return str(self)
Expand Down Expand Up @@ -128,8 +129,12 @@ def load_style_resnet_18(layers, image_size=500):

def load_vgg_16(layer_name, image_size=500, *args):
vgg = models.vgg16(pretrained=True).eval()
modules = list(vgg.children())
replace_relu_with_leaky(modules, ramp=0.1)
modules = list(vgg.modules())
# Replace relu in conv feature
replace_relu_with_leaky(modules[1], ramp=0.1)
# replace relu in dense features
replace_relu_with_leaky(modules[34], ramp=0.1)
vgg = modules[0]

max_layer = -1
if layer_name not in list(VGG16Layers):
Expand All @@ -138,7 +143,7 @@ def load_vgg_16(layer_name, image_size=500, *args):
max_layer = layer_name.value
nn_model = nn.Sequential(vgg.features[0:max_layer])

if layer_name == -1:
if layer_name == VGG16Layers.D3:
return "vgg16_{}".format("classes"), build_subsampler(image_size), vgg
else:
return "vgg16_{}".format(layer_name), build_subsampler(image_size), nn_model
Expand Down
3 changes: 2 additions & 1 deletion neural_styles/svg_optim/excitation_forward_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def gen_vgg16_excitation_func(layer_name, layer_index):
nn_model.to(device)

def func(img_batch, *args, **kwargs):
feature = nn_model.forward(img_batch)[:, layer_index, :, :]
layer_tensor = nn_model.forward(img_batch)
feature = layer_tensor[:, layer_index, :, :] if len(layer_tensor.shape) == 4 else layer_tensor[:, layer_index]
return -torch.sum(feature) + 0.00001 * tvloss(img_batch)

return func
Expand Down

0 comments on commit 87706fb

Please sign in to comment.