# feature_extract.py
# Made by Cyto
#        _ _
#       /> フ
#       | _ _l
#      /` ミ_xノ
#     /     |
#    /  ヽ   ノ
#    │  | | |
#  / ̄|   | | |
#  | ( ̄ヽ__ヽ_)__)
#  \二つ
import os
import os.path as osp
import cv2
import numpy as np
import torch
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision import transforms
from tqdm import tqdm
from src.utils import build_model
from src.initialize import extract_initialize
from src.arg_parse import extract_argParse
from src.constant import TRAIN_MEAN, TRAIN_STD, TEST_MEAN, TEST_STD


def __build_transform(subj, train=False):
    # Normalize with the per-subject statistics of the matching split.
    mean = TRAIN_MEAN[subj] if train else TEST_MEAN[subj]
    std = TRAIN_STD[subj] if train else TEST_STD[subj]
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
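
# A minimal usage sketch. Assumption: TRAIN_MEAN/TRAIN_STD/TEST_MEAN/TEST_STD (from
# src.constant) are dicts keyed by subject ID holding per-channel statistics, e.g.
# TRAIN_MEAN["subj01"] == (0.485, 0.456, 0.406):
#
#   tf = __build_transform("subj01", train=True)
#   tensor = tf(rgb_image)  # HWC ndarray -> normalized CHW torch.Tensor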


def main(args):
    # initialize
    args = extract_initialize(args)
    print("Initialization ready.")
    print(args)

    # init transform
    tf = __build_transform(args.subject, args.train)

    # setup model
    model = build_model(args.model, 100, args.pretrained_weights)
    model = create_feature_extractor(model, return_nodes=args.layers)
    print("Pretrained model loaded from {}".format(args.pretrained_weights))
    model.to(args.device)
    model.eval()  # inference only; set eval mode once instead of once per image
    print("Model initialized. Loaded to <{}> device.".format(args.device))
    # get inferred results
    print("Start feature extraction")
    args.data = osp.join(args.data, args.subject,
                         "training_split/training_images" if args.train else "test_split/test_images")
    print("Data retrieved from: {}".format(args.data))
    for img in tqdm(os.listdir(args.data)):
        # load img; skip anything OpenCV cannot decode
        img_data = cv2.imread(osp.join(args.data, img))
        if img_data is None:
            continue
        # Keep the 0-255 range as float32. ToTensor does not rescale float input,
        # so the normalization statistics are expected to be on the same scale.
        img_data = img_data.astype(np.float32)
        # convert BGR to RGB
        img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
        img_data = tf(img_data)
        img_data = img_data.to(args.device)  # .to() is not in-place; reassign
        # extract features for every requested layer
        with torch.no_grad():
            pred = model(img_data.unsqueeze(0))
        for k, v in pred.items():
            # one .npy per image, laid out as <save_path>/<layer_key>/<image_stem>.npy
            os.makedirs(osp.join(args.save_path, k), exist_ok=True)
            np.save(osp.join(args.save_path, k, "{}.npy".format(osp.splitext(img)[0])),
                    v.cpu().numpy().astype(np.float32))
    print("Done.")
if __name__ == "__main__":
args = extract_argParse()
main(args)
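
# Example invocation (a sketch only; the actual flag names are defined by
# extract_argParse in src/arg_parse.py and may differ):
#
#   python feature_extract.py --model resnet50 --subject subj01 \
#       --pretrained_weights weights/resnet50.pth --layers layer4 \
#       --data data --save_path features --device cuda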