Skip to content

Commit

Permalink
Fix liznerski#8 and add a runner for liznerski#15
Browse files Browse the repository at this point in the history
- No error is raised anymore when no anomalous samples are available during testing for globally normalized heatmap generation (liznerski#8)
- Added a script that loads a snapshot and provides pixel-wise anomaly score tensors for a folder of images (liznerski#15)
  • Loading branch information
liznerski committed Jul 12, 2021
1 parent 7f1119c commit eb30621
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 6 deletions.
85 changes: 85 additions & 0 deletions python/fcdd/runners/run_prediction_with_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
from fcdd.training.fcdd import FCDDTrainer
from fcdd.models.fcdd_cnn_224 import FCDD_CNN224_VGG_F
from fcdd.datasets.image_folder import ImageFolder
from fcdd.datasets.preprocessing import local_contrast_normalization

# -------- MvTec-AD pre-computed min and max values per class after lcn1 has been applied, copied from fcdd.datasets.mvtec --------
# Each entry i corresponds to MVTec-AD class index i and holds a pair of RGB triples:
# [(per-channel minimum), (per-channel maximum)] as observed on the training data after
# l1 local contrast normalization. Used below to shift/scale inputs via
# Normalize(min, max - min), i.e. into roughly the [0, 1] range.
min_max_l1 = [
[(-1.3336724042892456, -1.3107913732528687, -1.2445921897888184),
(1.3779616355895996, 1.3779616355895996, 1.3779616355895996)],
[(-2.2404820919036865, -2.3387579917907715, -2.2896201610565186),
(4.573435306549072, 4.573435306549072, 4.573435306549072)],
[(-3.184587001800537, -3.164201259613037, -3.1392977237701416),
(1.6995097398757935, 1.6011602878570557, 1.5209171772003174)],
[(-3.0334954261779785, -2.958242416381836, -2.7701096534729004),
(6.503103256225586, 5.875098705291748, 5.814228057861328)],
[(-3.100773334503174, -3.100773334503174, -3.100773334503174),
(4.27892541885376, 4.27892541885376, 4.27892541885376)],
[(-3.6565306186676025, -3.507692813873291, -2.7635035514831543),
(18.966819763183594, 21.64590072631836, 26.408710479736328)],
[(-1.5192601680755615, -2.2068002223968506, -2.3948357105255127),
(11.564697265625, 10.976534843444824, 10.378695487976074)],
[(-1.3207964897155762, -1.2889339923858643, -1.148416519165039),
(6.854909896850586, 6.854909896850586, 6.854909896850586)],
[(-0.9883341193199158, -0.9822461605072021, -0.9288841485977173),
(2.290637969970703, 2.4007883071899414, 2.3044068813323975)],
[(-7.236185073852539, -7.236185073852539, -7.236185073852539),
(3.3777384757995605, 3.3777384757995605, 3.3777384757995605)],
[(-3.2036616802215576, -3.221003532409668, -3.305514335632324),
(7.022546768188477, 6.115569114685059, 6.310940742492676)],
[(-0.8915618658065796, -0.8669204115867615, -0.8002046346664429),
(4.4255571365356445, 4.642300128936768, 4.305730819702148)],
[(-1.9086798429489136, -2.0004451274871826, -1.929288387298584),
(5.463134765625, 5.463134765625, 5.463134765625)],
[(-2.9547364711761475, -3.17536997795105, -3.143850803375244),
(5.305514812469482, 4.535006523132324, 3.3618252277374268)],
[(-1.2906527519226074, -1.2906527519226074, -1.2906527519226074),
(2.515115737915039, 2.515115737915039, 2.515115737915039)]
]
# ---------------------------------------------------------------------------------------------------------------------------------


# Path to your snapshot.pt
snapshot = "fcdd/data/mvtec_snapshot.pt"

# Pick the architecture that was used for the snapshot (mvtec's architecture defaults to the following)
net = FCDD_CNN224_VGG_F((3, 224, 224), bias=True).cuda()

# Path to a folder that contains a subfolder containing the images (this is required to use PyTorch's ImageFolder dataset).
# For instance, if the images are in foo/my_images/xxx.png, point to foo. Make sure foo contains only one folder (e.g., my_images).
images_path = "fcdd/data/datasets/foo"

# Pick the class the snapshot was trained on.
normal_class = 0

# Use the same test transform as was used for training the snapshot (e.g., for mvtec, per default, the following)
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Lambda(lambda x: local_contrast_normalization(x, scale='l1')),
transforms.Normalize(
min_max_l1[normal_class][0],
[ma - mi for ma, mi in zip(min_max_l1[normal_class][1], min_max_l1[normal_class][0])]
)
])

# Create a trainer to use its loss function for computing anomaly scores
ds = ImageFolder(images_path, transform, transforms.Lambda(lambda x: 0))
loader = DataLoader(ds, batch_size=16, num_workers=0)
trainer = FCDDTrainer(net, None, None, (None, None), None, 'fcdd', 8, 0.99, 128) # these parameters will have no effect (used for heatmaps)
trainer.load(snapshot)
# Run inference: compute a pixel-wise anomaly heatmap for every image.
all_anomaly_scores = []
with torch.no_grad():  # pure inference, no gradients required
    for imgs, lbls in loader:
        imgs = imgs.cuda()
        # Per-pixel loss of the forward pass, turned into per-pixel anomaly
        # scores at the network's (low) output resolution.
        scores = trainer.anomaly_score(trainer.loss(trainer.net(imgs), imgs, lbls, reduce='none'))
        # Upsample to input resolution using the receptive-field upsampling
        # (Gaussian kernel, std=8), staying on the GPU.
        scores = trainer.net.receptive_upsample(scores, reception=True, std=8, cpu=False)
        all_anomaly_scores.append(scores.cpu())

# One tensor containing pixel-wise anomaly scores for all images.
all_anomaly_scores = torch.cat(all_anomaly_scores)
10 changes: 4 additions & 6 deletions python/fcdd/training/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,11 @@ def _create_heatmaps_picture(self, idx: List[int], name: str, inpshp: torch.Size
rows = []
for s in range(number_of_rows):
rows.append(self._image_processing(imgs[idx][s * nrow:s * nrow + nrow], inpshp, maxres=self.resdown, qu=1))
err = self.objective != 'ae' and 'train' not in name # training samples might have just one label
if self.objective != 'hsc':
rows.append(
self._image_processing(
ascores[idx][s * nrow:s * nrow + nrow], inpshp, maxres=self.resdown, qu=self.quantile,
colorize=True, ref=balance_labels(ascores, labels, err) if norm == 'global' else ascores[idx],
colorize=True, ref=balance_labels(ascores, labels, False) if norm == 'global' else ascores[idx],
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
Expand All @@ -500,7 +499,7 @@ def _create_heatmaps_picture(self, idx: List[int], name: str, inpshp: torch.Size
self._image_processing(
grads[idx][s * nrow:s * nrow + nrow], inpshp, self.blur_heatmaps,
self.resdown, qu=self.quantile,
colorize=True, ref=balance_labels(grads, labels, err) if norm == 'global' else grads[idx],
colorize=True, ref=balance_labels(grads, labels, False) if norm == 'global' else grads[idx],
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
Expand Down Expand Up @@ -535,20 +534,19 @@ def _create_singlerow_heatmaps_picture(self, idx: List[int], name: str, inpshp:
"""
for norm in ['local', 'global']:
rows = [self._image_processing(imgs[idx], inpshp, maxres=res, qu=1)]
err = self.objective != 'ae' and 'train' not in name # training samples might have just one label
if self.objective != 'hsc':
rows.append(
self._image_processing(
ascores[idx], inpshp, maxres=res, colorize=True,
ref=balance_labels(ascores, labels, err) if norm == 'global' else None,
ref=balance_labels(ascores, labels, False) if norm == 'global' else None,
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
if grads is not None:
rows.append(
self._image_processing(
grads[idx], inpshp, self.blur_heatmaps, res, colorize=True,
ref=balance_labels(grads, labels, err) if norm == 'global' else None,
ref=balance_labels(grads, labels, False) if norm == 'global' else None,
norm=norm.replace('semi_', ''), # semi case is handled in the line above
)
)
Expand Down

0 comments on commit eb30621

Please sign in to comment.