-
Notifications
You must be signed in to change notification settings - Fork 32
/
ddad.py
executable file
·84 lines (72 loc) · 3.08 KB
/
ddad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from asyncio import constants
from typing import Any
import torch
from unet import *
from dataset import *
from visualize import *
from anomaly_map import *
from metrics import *
from feature_extractor import *
from reconstruction import *
# Expose only GPU devices 0, 1 and 2 to CUDA for this process.
# NOTE(review): `os` is not imported explicitly among the imports visible
# above; it presumably arrives through one of the star imports — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2"
class DDAD:
    """Evaluation driver that scores the test split with a given UNet.

    On construction it builds the test dataset/loader (``is_train=False``)
    and a ``Reconstruction`` helper. Calling the instance reconstructs every
    test batch, derives an anomaly map per batch via ``heat_map``, computes
    the metrics enabled in ``config.metrics`` and optionally visualises the
    results.
    """

    def __init__(self, unet, config) -> None:
        """Set up the test data pipeline and the reconstruction helper.

        Args:
            unet: Model handed to ``Reconstruction`` and, at call time, to
                ``domain_adaptation``.
            config: Configuration object. This class reads
                ``config.data.{data_dir, category, test_batch_size}``,
                ``config.model.{num_workers, device, w}`` and
                ``config.metrics.{auroc, pro, misclassifications,
                visualisation}``.
        """
        self.test_dataset = Dataset_maker(
            root=config.data.data_dir,
            category=config.data.category,
            config=config,
            is_train=False,
        )
        # Deterministic order and no dropped tail batch, so every test
        # sample is scored exactly once.
        self.testloader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=config.data.test_batch_size,
            shuffle=False,
            num_workers=config.model.num_workers,
            drop_last=False,
        )
        self.unet = unet
        self.config = config
        self.reconstruction = Reconstruction(self.unet, self.config)
        # The same centre crop is applied to anomaly maps and ground-truth
        # masks in __call__.
        # NOTE(review): `transforms` is presumably torchvision.transforms
        # brought in by a star import — confirm.
        self.transform = transforms.Compose([
            transforms.CenterCrop(224),
        ])

    def __call__(self) -> Any:
        """Run inference over the test loader and report the enabled metrics.

        For each batch the input is reconstructed, an anomaly map is computed
        from the reconstruction and the input, and the maximum value of each
        sample's anomaly map is used as its image-level score. Labels equal
        to ``'good'`` are encoded as 0, everything else as 1.

        Side effects:
            Prints AUROC / PRO when enabled, calls the misclassification
            report when enabled, creates the ``results`` directory, and
            calls ``visualize`` when enabled.

        Returns:
            Nothing is returned explicitly.
        """
        feature_extractor = domain_adaptation(self.unet, self.config, fine_tune=False)
        feature_extractor.eval()

        labels_list = []
        predictions = []
        anomaly_map_list = []
        gt_list = []
        reconstructed_list = []
        forward_list = []

        with torch.no_grad():
            # `inputs` rather than `input`, to avoid shadowing the builtin.
            for inputs, gt, labels in self.testloader:
                inputs = inputs.to(self.config.model.device)
                # Only the last element of the reconstruction output is used.
                # NOTE(review): presumably the final reconstruction step —
                # confirm against Reconstruction.__call__.
                x0 = self.reconstruction(inputs, inputs, self.config.model.w)[-1]
                anomaly_map = heat_map(x0, inputs, feature_extractor, self.config)

                anomaly_map = self.transform(anomaly_map)
                gt = self.transform(gt)

                forward_list.append(inputs)
                anomaly_map_list.append(anomaly_map)
                gt_list.append(gt)
                reconstructed_list.append(x0)

                for pred, label in zip(anomaly_map, labels):
                    labels_list.append(0 if label == 'good' else 1)
                    # Image-level score: peak value of the sample's map.
                    predictions.append(torch.max(pred).item())

        metric = Metric(labels_list, predictions, anomaly_map_list, gt_list, self.config)
        # Must run before `metric.threshold` is read below.
        metric.optimal_threshold()

        if self.config.metrics.auroc:
            image_auroc = metric.image_auroc() * 100
            pixel_auroc = metric.pixel_auroc() * 100
            print(f'AUROC: ({image_auroc:.1f},{pixel_auroc:.1f})')
        if self.config.metrics.pro:
            print('PRO: {:.1f}'.format(metric.pixel_pro() * 100))
        if self.config.metrics.misclassifications:
            # Spelling follows the method name used by the Metric class.
            metric.miscalssified()

        # Collapse the per-batch lists into single tensors along dim 0.
        reconstructed_list = torch.cat(reconstructed_list, dim=0)
        forward_list = torch.cat(forward_list, dim=0)
        anomaly_map_list = torch.cat(anomaly_map_list, dim=0)
        pred_mask = (anomaly_map_list > metric.threshold).float()
        gt_list = torch.cat(gt_list, dim=0)

        # exist_ok avoids the check-then-create race of
        # `if not os.path.exists(...): os.mkdir(...)`.
        os.makedirs('results', exist_ok=True)

        if self.config.metrics.visualisation:
            visualize(
                forward_list,
                reconstructed_list,
                gt_list,
                pred_mask,
                anomaly_map_list,
                self.config.data.category,
            )