Skip to content

Commit

Permalink
Extended run_prediction_with_snapshot script to generate heatmaps if …
Browse files Browse the repository at this point in the history
…requested
  • Loading branch information
liznerski committed Nov 4, 2021
1 parent 7df3b9a commit 9ded37d
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions python/fcdd/runners/run_prediction_with_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fcdd.models.fcdd_cnn_224 import FCDD_CNN224_VGG_F
from fcdd.datasets.image_folder import ImageFolder
from fcdd.datasets.preprocessing import local_contrast_normalization
from fcdd.util.logging import Logger

# -------- MvTec-AD pre-computed min and max values per class after lcn1 has been applied, copied from fcdd.datasets.mvtec --------
min_max_l1 = [
Expand Down Expand Up @@ -67,19 +68,37 @@
)
])

# [optional] to generate heatmaps, define a logger (with the path where the heatmaps should be saved to) and a quantile
logger = None # Logger("fcdd/data/results/foo")
quantile = 0.97

# Create a trainer to use its loss function for computing anomaly scores
ds = ImageFolder(images_path, transform, transforms.Lambda(lambda x: 0))
loader = DataLoader(ds, batch_size=16, num_workers=0)
trainer = FCDDTrainer(net, None, None, (None, None), None, 'fcdd', 8, 0.99, 128) # these parameters will have no effect (used for heatmaps)
trainer = FCDDTrainer(net, None, None, (None, None), logger, 'fcdd', 8, quantile, 224)
trainer.load(snapshot)
all_anomaly_scores = []
all_anomaly_scores, all_inputs, all_labels = [], [], []
for inputs, labels in loader:
inputs = inputs.cuda()
with torch.no_grad():
outputs = trainer.net(inputs)
anomaly_scores = trainer.anomaly_score(trainer.loss(outputs, inputs, labels, reduce='none'))
anomaly_scores = trainer.net.receptive_upsample(anomaly_scores, reception=True, std=8, cpu=False)
all_anomaly_scores.append(anomaly_scores.cpu())
all_inputs.append(inputs.cpu())
all_labels.append(labels)
all_inputs = torch.cat(all_inputs)
all_labels = torch.cat(all_labels)

# all_anomaly_scores will be a tensor containing pixel-wise anomaly scores for all images
all_anomaly_scores = torch.cat(all_anomaly_scores)

# transform the pixel-wise anomaly scores to sample-wise anomaly scores
print(trainer.reduce_ascore(all_anomaly_scores))

# if there is a logger, create heatmaps and save them to the previously defined path using the logger
if logger is not None:
# show_per_cls defines the maximum number of samples in the heatmaps figures.
# The heatmap_paper_xxx.png figures, which sort the heatmaps by their anomaly score,
# use only up to a third of show_per_cls samples.
trainer.heatmap_generation(all_labels.tolist(), all_anomaly_scores, all_inputs, show_per_cls=1000)

0 comments on commit 9ded37d

Please sign in to comment.