-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_custom.py
executable file
·41 lines (35 loc) · 1.85 KB
/
inference_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from pathlib import Path
import torch
import torchvision
from torchvision import models, transforms, utils
import argparse
from utils import *
from qatm_pytorch_custom import CreateModel, ImageDataset,nms_multi,nms,run_multi_sample,plot_result_multi
# +
# import functions and classes from qatm_pytorch.py
print("import qatm_pytorch.py...")
import ast
import types
import sys
def parse_args(argv=None):
    """Parse the command-line options of the QATM inference script.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``, so running the script from the shell behaves as
            before.

    Returns:
        argparse.Namespace with the attributes ``cuda``, ``sample_image``,
        ``template_images_dir``, ``alpha``, ``thresh_csv`` and ``output``.
    """
    parser = argparse.ArgumentParser(description='QATM Pytorch Implementation')
    parser.add_argument('--cuda', action='store_true',
                        help='pass use_cuda=True to CreateModel')
    parser.add_argument('-s', '--sample_image', default='data/sample/3.jpg',
                        help='image in which the templates are searched')
    parser.add_argument('-t', '--template_images_dir', default='data/cust_template/',
                        help='directory containing the template images')
    parser.add_argument('--alpha', type=float, default=25,
                        help='alpha value forwarded to CreateModel')
    parser.add_argument('--thresh_csv', type=str, default='thresh_template.csv',
                        help='threshold CSV forwarded to ImageDataset')
    parser.add_argument('-o', '--output', default='result.png',
                        help='file name the plotted result is saved to')
    return parser.parse_args(argv)


def main(argv=None):
    """Run QATM template matching on one image and save the plotted result.

    Builds an ImageDataset from the template directory and sample image, and
    scores it with a CreateModel wrapping VGG19 convolutional features. It then
    applies multi-template NMS and plots the resulting boxes (``show=True``),
    saving the figure to ``--output``.

    Args:
        argv: Optional argument list for ``parse_args``; ``None`` uses the
            process command line.
    """
    args = parse_args(argv)
    # Honour the --thresh_csv option instead of a hard-coded file name, so the
    # parsed argument actually takes effect.
    dataset = ImageDataset(Path(args.template_images_dir), args.sample_image,
                           thresh_csv=args.thresh_csv)
    print("define model...")
    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
    # favour of `weights=models.VGG19_Weights.DEFAULT`; kept here so older
    # torchvision releases keep working.
    model = CreateModel(model=models.vgg19(pretrained=True).features,
                        alpha=args.alpha, use_cuda=args.cuda)
    print("calculate score...")
    scores, w_array, h_array, thresh_list = run_multi_sample(model, dataset)
    print("nms...")
    boxes, indices = nms_multi(scores, w_array, h_array, thresh_list)
    plot_result_multi(dataset.image_raw, boxes, indices, show=True,
                      save_name=args.output)
    print(f"{args.output} was saved")


if __name__ == '__main__':
    main()