# video_itm_dataset.py (forked from salesforce/BLIP)

import os
import json
import copy
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
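

# Expected data layout (inferred from the code below):
#   - video_root/<ann['video']>/ holds pre-extracted frame images for one clip
#   - ann_root/test_itm.json is a list of dicts with keys 'video', 'caption',
#     'video_id', and optionally 'ts' = [start_time, end_time] in seconds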
class video_itm_eval(Dataset):
def __init__(self, transform, video_root, ann_root, max_words=30, prompt='', max_img_size=224, num_frm=1):
        '''
        video_root (string): Root directory of pre-extracted video frames (e.g. msrvtt/frms/)
        ann_root (string): directory that stores the annotation file (test_itm.json)
        '''
filename = 'test_itm.json'
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.video_root = video_root
self.max_words = max_words
self.prompt = prompt
self.max_img_size = max_img_size
self.num_frm = num_frm
self.video_ids = {}
n = 1
for ann in self.annotation:
video_id = ann['video']
if video_id not in self.video_ids.keys():
self.video_ids[video_id] = n
n += 1

    def __len__(self):
return len(self.annotation)

    def __getitem__(self, index):
ann = self.annotation[index]
video_path = os.path.join(self.video_root, ann['video'])
if 'ts' in ann:
video = load_video_from_path_decord(video_path, self.transform, height=self.max_img_size, width=self.max_img_size, start_time=ann['ts'][0], end_time=ann['ts'][1], fps=3, num_frm=self.num_frm)
else:
video = load_video_from_path_decord(video_path, self.transform, height=self.max_img_size, width=self.max_img_size, num_frm=self.num_frm)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return video, caption, ann['video_id']
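

# A minimal usage sketch (assumptions: a torchvision transform and illustrative
# paths/batch sizes that are not part of this file):
#
#   from torch.utils.data import DataLoader
#   from torchvision import transforms
#
#   tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
#   dataset = video_itm_eval(tfm, video_root='msrvtt/frms/', ann_root='annotation/', num_frm=4)
#   loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)
#   for video, caption, video_id in loader:
#       ...  # score each (video, caption) pair with the ITM head
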
def load_video_from_path_decord(video_path, transform, height=None, width=None, start_time=None, end_time=None, fps=-1,
                                num_frm=1, frm_sampling_strategy='uniform'):
    '''
    Despite the name, this reads pre-extracted frame images from the directory video_path
    (it does not decode the video with decord). Returns a tensor of sampled, transformed
    frames, or None if loading fails.
    '''
def expand_video_frms(video_frms):
video_frms_clone = copy.deepcopy(video_frms)
video_frms_ret = []
for i in range(len(video_frms)):
video_frms_ret.append(video_frms[i])
video_frms_ret.append(video_frms_clone[i])
return video_frms_ret
try:
        video_frms = sorted(os.listdir(video_path))  # sort so frames are in temporal order (os.listdir order is arbitrary)
vlen = len(video_frms)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
if start_idx < end_idx:
video_frms = video_frms[start_idx:end_idx]
        # duplicate frames until there are at least num_frm of them
        while len(video_frms) < num_frm:
video_frms = expand_video_frms(video_frms)
vlen = len(video_frms)
start_idx, end_idx = 0, vlen
if frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / num_frm, dtype=int)
# frame_indices = np.linspace(start_idx, end_idx-1, num_frm, dtype=int)
elif frm_sampling_strategy == 'rand':
frame_indices = sorted(random.sample(range(vlen), num_frm))
elif frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), num_frm // 2))
frame_indices = frame_indices_head + frame_indices_tail
elif frm_sampling_strategy == 'tsn':
            # divide all frames into num_frm segments of equal duration, then randomly sample one frame from each segment
frame_indices = []
for i in range(num_frm):
frame_indices.append(random.randrange(start_idx + i * vlen // num_frm, start_idx + (i + 1) * vlen // num_frm))
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(frm_sampling_strategy))
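        # Example (sketch): with vlen = 12 and num_frm = 4,
        #   'uniform'  -> indices [0, 3, 6, 9]
        #   'rand'     -> 4 sorted random indices from range(12)
        #   'headtail' -> 2 random indices from [0, 6) plus 2 from [6, 12)
        #   'tsn'      -> one random index from each of [0, 3), [3, 6), [6, 9), [9, 12)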
        # raw_sample_frms = vr.get_batch(frame_indices)  # (num_frm, height, width, channel)
# pre-process frames
images = []
for index in frame_indices:
image_path = os.path.join(video_path, video_frms[index])
            images.append(transform(Image.open(image_path).convert("RGB")))  # (num_frm, channel, height, width)
# images.append(Image.open(image_path).convert("RGB"))
# convert into tensor
if len(images) > 0:
raw_sample_frms = torch.tensor(np.stack(images))
else:
raw_sample_frms = torch.zeros(1)
except Exception as e:
print(e)
return None
    # raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)  # (num_frm, channel, height, width)
return raw_sample_frms
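

# Note (assuming a ToTensor-style transform): the returned tensor has shape
# (num_frm, channel, height, width). On any exception the function prints the
# error and returns None, so callers should check for that.
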
if __name__ == '__main__':
    from torchvision import transforms
    # simple smoke test: sample one frame from a local frame folder
    video_path = '/cfs/cfs-rmuhzak3/lukenxu/dataset/msvd/frames_fps5_224/eyhzdC936uk_15_27'
    test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    video = load_video_from_path_decord(video_path, test_transform, height=224, width=224,
                                        num_frm=1, frm_sampling_strategy='uniform')