forked from praveenVnktsh/Fast-Road-Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
112 lines (81 loc) · 2.78 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import print_function
from models.litmodel import LitModel
from models.fcn32s import FCN32s
from models.vgg import VGGNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models
from torchvision.models.vgg import VGG
import pytorch_lightning as pl
from torch.optim.rmsprop import RMSprop
from dataloader import CustomDataset, lit_custom_data
from pytorch_lightning import loggers
from configs import Configs
from icecream import ic
import os
import cv2
import numpy as np
import timeit
# Per-frame metrics appended by test(): IoU scores and frames-per-second.
# They live at module level and are never cleared, so values accumulate
# across repeated test() calls.
avg_iou, avg_fps = [], []
def _iou(target, pred):
    """Return the intersection-over-union of two masks, each treated as boolean.

    Any nonzero value counts as foreground. When both masks are empty the
    union is zero; that frame is scored as a perfect match (1.0) instead of
    producing NaN, which would otherwise turn every later running average
    into NaN as well.
    """
    target = np.asarray(target).astype(bool)
    pred = np.asarray(pred).astype(bool)
    union = np.count_nonzero(target | pred)
    if union == 0:
        return 1.0
    return float(np.count_nonzero(target & pred) / union)


def _render_overlay(real, pred, target):
    """Blend prediction and ground-truth masks over the input frame for display.

    The prediction is written to channel 2 and the ground truth to channel 0
    (red and blue respectively when shown with cv2.imshow, which expects BGR).
    Only element [0] of `real` is rendered, and it is transposed from
    (C, H, W) to (H, W, C); presumably `real` is a float tensor in [0, 1]
    -- TODO confirm against the dataloader.
    Returns a 500x500 uint8 image.
    """
    img = np.transpose((real * 255)[0].numpy().astype("uint8"), (1, 2, 0))
    overlay = np.zeros(img.squeeze().shape)
    overlay[:, :, 2] = pred
    overlay[:, :, 0] = target
    overlay = (overlay * 255).astype("uint8")
    blended = cv2.addWeighted(img, 1, overlay, 0.5, 0)
    return cv2.resize(blended, (500, 500), interpolation=cv2.INTER_AREA)


def test(iterator, verbose=True):
    """Evaluate the module-level ``model`` on ``iterator``, printing IoU and FPS.

    Each batch is a dict with keys 'input' (moved to CUDA and fed to the
    model), 'target' (ground-truth mask) and 'real' (the frame used for the
    overlay). The model output is thresholded at 0.5 to form the predicted
    mask. Per-frame IoU and FPS are appended to the module-level ``avg_iou``
    and ``avg_fps`` lists, and their means are reported at the end.

    Outputs are squeezed and only frame [0] is rendered, so this appears to
    assume a batch size of 1 -- TODO confirm against the dataloader.

    Args:
        iterator: Iterable of batch dicts, e.g. a DataLoader.
        verbose: When True, show each frame with the masks overlaid in an
            OpenCV window; press 'q' to stop early.
    """
    model.freeze()
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            image = batch['input'].cuda()
            target = batch['target']
            real = batch['real']

            # CUDA kernels run asynchronously. Synchronize on both sides so the
            # timer covers the forward pass itself, not just its launch.
            torch.cuda.synchronize()
            start = timeit.default_timer()
            output = model(image)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            t = target.detach().cpu().squeeze().numpy()
            o = output.detach().cpu().squeeze().numpy() > 0.5
            iou = _iou(t, o)
            avg_iou.append(iou)
            avg_fps.append(1 / (end - start))
            print(
                f"frametime: {(end-start)*1000}ms, iou: {iou} avg: {np.mean(np.array(avg_iou))}, avgfps:{np.mean(np.array(avg_fps))}")

            if verbose:
                cv2.imshow('Out', _render_overlay(real, o, t))
                # waitKey also gives the window time to refresh; 'q' quits early.
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    break
    cv2.destroyAllWindows()
    print("||STATS||")
    ic(np.mean(np.array(avg_fps)))
    ic(np.mean(np.array(avg_iou)))
# Script entry point: load the trained checkpoint, prepare the data module and
# run the interactive evaluation in test().
if __name__ == '__main__':
    # Build the path portably. A backslash-separated literal only resolves on
    # Windows, and sequences like "\R", "\c" and "\e" are invalid string
    # escapes that newer Python versions warn about.
    checkpoint_path = os.path.join(
        "lightning_logs", "Resnet_LSTM_FCN_2", "checkpoints",
        "epoch=119-step=3839.ckpt")
    # test() reads this module-level `model`, so it must remain a global here.
    model = LitModel.load_from_checkpoint(checkpoint_path).cuda()
    dataset = lit_custom_data()
    dataset.setup()
    test(dataset.test_dataloader())
    print("hello")