Skip to content

Commit

Permalink
change default arguments for 224x224 datasets + fix semi-global norma…
Browse files Browse the repository at this point in the history
…lization
  • Loading branch information
liznerski committed Sep 8, 2020
1 parent 501c7f9 commit 175a3c2
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 18 deletions.
85 changes: 82 additions & 3 deletions python/fcdd/models/fcdd_cnn_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ def __init__(self, in_shape, **kwargs):
super().__init__(in_shape, **kwargs)
self.conv1 = self._create_conv2d(in_shape[0], 8, 5, bias=self.bias, padding=2)
self.bn2d1 = nn.BatchNorm2d(8, eps=1e-04, affine=self.bias)
self.pool1 = self._create_maxpool2d(3, 2, 1) # 8 x 112 x 112
self.pool1 = self._create_maxpool2d(3, 2, 1) # 32 x 112 x 112

self.conv2 = self._create_conv2d(8, 32, 5, bias=self.bias, padding=2)
self.bn2d2 = nn.BatchNorm2d(32, eps=1e-04, affine=self.bias)
self.pool2 = self._create_maxpool2d(3, 2, 1) # 32 x 56 x 56
self.pool2 = self._create_maxpool2d(3, 2, 1) # 128 x 56 x 56

self.conv3 = self._create_conv2d(32, 64, 3, bias=self.bias, padding=1)
self.bn2d3 = nn.BatchNorm2d(64, eps=1e-04, affine=self.bias)
self.conv4 = self._create_conv2d(64, 128, 3, bias=self.bias, padding=1)
self.bn2d4 = nn.BatchNorm2d(128, eps=1e-04, affine=self.bias)
self.pool3 = self._create_maxpool2d(3, 2, 1) # 128 x 28 x 28
self.pool3 = self._create_maxpool2d(3, 2, 1) # 256 x 28 x 28

self.conv5 = self._create_conv2d(128, 128, 3, bias=self.bias, padding=1)
self.encoder_out_shape = (128, 28, 28)
Expand Down Expand Up @@ -48,6 +48,50 @@ def forward(self, x, ad=True):
return x


class FCDD_CNN224_W(FCDDNet):
def __init__(self, in_shape, **kwargs):
super().__init__(in_shape, **kwargs)
self.conv1 = self._create_conv2d(in_shape[0], 32, 5, bias=self.bias, padding=2)
self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=self.bias)
self.pool1 = self._create_maxpool2d(3, 2, 1) # 32 x 112 x 112

self.conv2 = self._create_conv2d(32, 128, 5, bias=self.bias, padding=2)
self.bn2d2 = nn.BatchNorm2d(128, eps=1e-04, affine=self.bias)
self.pool2 = self._create_maxpool2d(3, 2, 1) # 128 x 56 x 56

self.conv3 = self._create_conv2d(128, 256, 3, bias=self.bias, padding=1)
self.bn2d3 = nn.BatchNorm2d(256, eps=1e-04, affine=self.bias)
self.conv4 = self._create_conv2d(256, 256, 3, bias=self.bias, padding=1)
self.bn2d4 = nn.BatchNorm2d(256, eps=1e-04, affine=self.bias)
self.pool3 = self._create_maxpool2d(3, 2, 1) # 256 x 28 x 28

self.conv5 = self._create_conv2d(256, 128, 3, bias=self.bias, padding=1)
self.encoder_out_shape = (128, 28, 28)
self.conv_final = self._create_conv2d(128, 1, 1, bias=self.bias)

def forward(self, x, ad=True):
x = self.conv1(x)
x = F.leaky_relu(self.bn2d1(x))
x = self.pool1(x)

x = self.conv2(x)
x = F.leaky_relu(self.bn2d2(x))
x = self.pool2(x)

x = self.conv3(x)
x = F.leaky_relu(self.bn2d3(x))
x = self.conv4(x)
x = F.leaky_relu(self.bn2d4(x))
x = self.pool3(x)

x = self.conv5(x)

if ad:
x = self.conv_final(x) # n x heads x h' x w'

return x


class FCDD_AE224(BaseNet):
encoder_cls = FCDD_CNN224

Expand Down Expand Up @@ -80,3 +124,38 @@ def forward(self, x):
x = self.deconv5(x)
x = torch.sigmoid(x)
return x


class FCDD_AE224_W(BaseNet):
encoder_cls = FCDD_CNN224_W

def __init__(self, encoder, **kwargs):
super().__init__(encoder.in_shape, bias=encoder.bias, **kwargs)
self.encoder = encoder

self.bn0 = nn.BatchNorm2d(128, eps=1e-05, affine=self.bias)
self.deconv1 = nn.ConvTranspose2d(128, 256, 3, bias=self.bias, padding=1)
self.bn1 = nn.BatchNorm2d(256, eps=1e-04, affine=self.bias)
self.deconv2 = nn.ConvTranspose2d(256, 256, 3, bias=self.bias, padding=1)
self.bn2 = nn.BatchNorm2d(256, eps=1e-04, affine=self.bias)
self.deconv3 = nn.ConvTranspose2d(256, 128, 3, bias=self.bias, padding=1)
self.bn3 = nn.BatchNorm2d(128, eps=1e-04, affine=self.bias)
self.deconv4 = nn.ConvTranspose2d(128, 32, 3, bias=self.bias, padding=1)
self.bn4 = nn.BatchNorm2d(32, eps=1e-04, affine=self.bias)
self.deconv5 = nn.ConvTranspose2d(32, self.in_shape[0], 5, bias=self.bias, padding=2)

def forward(self, x):
x = self.encoder(x, ad=False)
x = F.leaky_relu(self.bn0(x))
x = self.deconv1(x)
x = F.interpolate(F.leaky_relu(self.bn1(x)), scale_factor=2)
x = self.deconv2(x)
x = F.leaky_relu(self.bn2(x))
x = self.deconv3(x)
x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)
x = self.deconv4(x)
x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)
x = self.deconv5(x)
x = torch.sigmoid(x)
return x

6 changes: 3 additions & 3 deletions python/fcdd/runners/argparse_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __call__(self, parser: ArgumentParser) -> ArgumentParser:
)
parser.add_argument('-d', '--dataset', type=str, default='mvtec', choices=DS_CHOICES)
parser.add_argument(
'-n', '--net', type=str, default='FCDD_CNN224', choices=choices(),
'-n', '--net', type=str, default='FCDD_CNN224_W', choices=choices(),
help='Chooses a network architecture to train. Note that not all architectures fit every objective. '
)
parser.add_argument(
Expand Down Expand Up @@ -189,7 +189,7 @@ def __call__(self, parser: ArgumentParser) -> ArgumentParser:
parser = super().__call__(parser)
parser.set_defaults(
batch_size=16, acc_batches=8, supervise_mode='malformed_normal',
gauss_std=12, weight_decay=1e-5, epochs=200
gauss_std=12, weight_decay=1e-5, epochs=200, preproc='lcnaug1'
)
return parser

Expand All @@ -201,7 +201,7 @@ def __call__(self, parser: ArgumentParser) -> ArgumentParser:
batch_size=20, acc_batches=10, epochs=600,
optimizer_type='adam', scheduler_type='milestones',
lr_sched_param=[0.1, 400, 500], noise_mode='imagenet22k',
dataset='imagenet', net='FCDD_CNN224'
dataset='imagenet'
)
return parser

Expand Down
4 changes: 2 additions & 2 deletions python/fcdd/runners/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def extract_viz_ids(dir: str, cls: str, it: int):
return viz_ids


NET_TO_HSC = {'FCDD_CNN224': 'CNN224', 'FCDD_CNN32_S': 'CNN32', 'FCDD_CNN28': 'CNN28'}
NET_TO_AE = {'FCDD_CNN224': 'AE224', 'FCDD_CNN32_S': 'AE32', 'FCDD_CNN28': 'AE28'}
NET_TO_HSC = {'FCDD_CNN224': 'CNN224', 'FCDD_CNN32_S': 'CNN32', 'FCDD_CNN28': 'CNN28', 'FCDD_CNN224_W': 'CNN224'}
NET_TO_AE = {'FCDD_CNN224': 'AE224', 'FCDD_CNN32_S': 'AE32', 'FCDD_CNN28': 'AE28', 'FCDD_CNN224_W': 'AE224_W'}


class BaseRunner(object):
Expand Down
23 changes: 13 additions & 10 deletions python/fcdd/training/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,16 @@ def _create_heatmaps_picture(self, idx: [int], name: str, inpshp: torch.Size, su
rows.append(
self._image_processing(
ascores[idx][s * nrow:s * nrow + nrow], inpshp, maxres=self.resdown, qu=self.quantile,
colorize=True, norm=norm, ref=ascores if norm == 'global' else None
colorize=True, ref=ascores if norm == 'global' else ascores[idx],
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
if grads is not None:
rows.append(
self._image_processing(
grads[idx][s * nrow:s * nrow + nrow], inpshp, self.blur_heatmaps, self.resdown,
qu=self.quantile, colorize=True, norm=norm, ref=grads if norm == 'global' else None
qu=self.quantile, colorize=True, ref=grads if norm == 'global' else grads[idx],
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
if gtmaps is not None:
Expand Down Expand Up @@ -517,15 +519,17 @@ def _create_singlerow_heatmaps_picture(self, idx: [int], name: str, inpshp: torc
if self.objective != 'hsc':
rows.append(
self._image_processing(
ascores[idx], inpshp, maxres=res, colorize=True, norm=norm,
ref=ascores if norm == 'global' else None
ascores[idx], inpshp, maxres=res, colorize=True,
ref=ascores if norm == 'global' else None,
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
if grads is not None:
rows.append(
self._image_processing(
grads[idx], inpshp, self.blur_heatmaps, res, colorize=True, norm=norm,
ref=grads if norm == 'global' else None
grads[idx], inpshp, self.blur_heatmaps, res, colorize=True,
ref=grads if norm == 'global' else None,
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
if gtmaps is not None:
Expand All @@ -551,10 +555,9 @@ def _image_processing(self, imgs: Tensor, input_shape: torch.Size, blur: bool =
None: no normalization.
'local': normalizes each image w.r.t. itself only.
'global': normalizes each image w.r.t. to ref.
'semi_global': normalizes each image w.r.t. all images.
:param qu: quantile used for normalization, qu=1 yields the typical 0-1 normalization.
:param colorize: whether to colorize grayscaled images using colormaps (-> pseudocolored heatmaps!).
:param ref: a tensor of images used for global normalization.
:param ref: a tensor of images used for global normalization (defaults to imgs).
:param cmap: the colormap that is used to colorize grayscaled images.
:param inplace: whether to perform the operations inplace.
:return: transformed tensor of images
Expand Down Expand Up @@ -598,9 +601,9 @@ def _image_processing(self, imgs: Tensor, input_shape: torch.Size, blur: bool =
# apply requested normalization
if norm is not None:
apply_norm = {
'local': self.__local_norm, 'global': self.__global_norm, 'semi_global': self.__global_norm
'local': self.__local_norm, 'global': self.__global_norm,
}
imgs = apply_norm[norm](imgs, qu, ref if norm == 'global' else None)
imgs = apply_norm[norm](imgs, qu, ref)

# if image is grayscaled, colorize, i.e. provide a pseudocolored heatmap!
if colorize:
Expand Down

0 comments on commit 175a3c2

Please sign in to comment.