-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathcaption_dataset.py
114 lines (79 loc) · 3.24 KB
/
caption_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from dataset.utils import pre_caption
class re_train_dataset(Dataset):
    """Image-text retrieval training dataset.

    Each item yields two independently augmented views of the same image
    (for contrastive-style training) plus its pre-processed caption and a
    dense per-image index usable as a matching label.

    Args:
        ann_file: list of paths to JSON annotation files; each file contains
            a list of dicts with keys 'image', 'image_id', and 'caption'.
        transform: callable image transform, applied twice per item to get
            two independent augmentations.
        image_root: directory that each annotation's 'image' path is
            relative to.
        max_words: maximum number of words kept per caption by pre_caption.
    """
    def __init__(self, ann_file, transform, image_root, max_words=30):
        self.ann = []
        for f in ann_file:
            # Context manager closes each annotation file promptly
            # (the previous bare open() leaked the handle).
            with open(f, 'r') as fp:
                self.ann += json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        # Map each distinct image_id to a dense index 0..K-1, in order of
        # first appearance in the concatenated annotation list.
        self.img_ids = {}
        for ann in self.ann:
            img_id = ann['image_id']
            if img_id not in self.img_ids:
                self.img_ids[img_id] = len(self.img_ids)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        # Two independent augmentations of the same underlying image (jinyu).
        image1 = self.transform(image)
        image2 = self.transform(image)

        caption = pre_caption(ann['caption'], self.max_words)

        return image1, image2, caption, self.img_ids[ann['image_id']]
class re_eval_dataset(Dataset):
    """Image-text retrieval evaluation dataset.

    Flattens the per-image caption lists into one global text list and
    keeps bidirectional image<->text index mappings so retrieval metrics
    can be computed over all (image, caption) pairs.

    Args:
        ann_file: path to a single JSON annotation file containing a list
            of dicts with keys 'image' and 'caption' (a list of strings).
        transform: callable image transform applied once per item.
        image_root: directory that each annotation's 'image' path is
            relative to.
        max_words: maximum number of words kept per caption by pre_caption.
    """
    def __init__(self, ann_file, transform, image_root, max_words=30):
        # Context manager closes the annotation file promptly
        # (the previous bare open() leaked the handle).
        with open(ann_file, 'r') as fp:
            self.ann = json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        self.text = []      # all pre-processed captions, flattened
        self.image = []     # image path for each image index
        self.txt2img = {}   # text index -> owning image index
        self.img2txt = {}   # image index -> list of its text indices

        txt_id = 0
        for img_id, ann in enumerate(self.ann):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for caption in ann['caption']:
                self.text.append(pre_caption(caption, self.max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        # Length is the number of images, not the number of captions.
        return len(self.image)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.ann[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
class pretrain_dataset(Dataset):
    """Image-text pre-training dataset.

    Each item yields two independently augmented views of an image plus a
    pre-processed caption; when an annotation holds several captions, one
    is sampled uniformly at random per access.

    Args:
        ann_file: list of paths to JSON annotation files; each file contains
            a list of dicts with keys 'image' (a full path — no image_root
            prefix is applied here) and 'caption' (a string or a list of
            strings).
        transform: callable image transform, applied twice per item.
        max_words: maximum number of words kept per caption by pre_caption.
    """
    def __init__(self, ann_file, transform, max_words=30):
        self.ann = []
        for f in ann_file:
            # Context manager closes each annotation file promptly
            # (the previous bare open() leaked the handle).
            with open(f, 'r') as fp:
                self.ann += json.load(fp)
        self.transform = transform
        self.max_words = max_words

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]

        # Some annotations store multiple captions; sample one at random.
        # isinstance() replaces the unidiomatic `type(...) == list` check.
        if isinstance(ann['caption'], list):
            caption = pre_caption(random.choice(ann['caption']), self.max_words)
        else:
            caption = pre_caption(ann['caption'], self.max_words)

        # 'image' is used as-is: pretrain annotations carry full paths.
        image = Image.open(ann['image']).convert('RGB')
        image1 = self.transform(image)
        image2 = self.transform(image)

        return image1, image2, caption