base_dataset.py
import io
import os
import random

import pyarrow as pa
import torch
from PIL import Image

from ..transforms import keys_to_transforms


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        transform_keys: list,
        image_size: int,
        names: list,
        text_column_name: str = "",
        remove_duplicate=True,
        max_text_len=40,
        draw_false_image=0,
        draw_false_text=0,
        image_only=False,
        tokenizer=None,
    ):
        """
        data_dir : directory where the dataset's *.arrow files live; their
            existence should be guaranteed via DataModule.prepare_data
        transform_keys : keys for generating augmented views of the images
        text_column_name : pyarrow table column whose elements are lists of strings
        """
        assert len(transform_keys) >= 1
        super().__init__()

        self.transforms = keys_to_transforms(transform_keys, size=image_size)
        self.clip_transform = False
        for transform_key in transform_keys:
            if "clip" in transform_key:
                self.clip_transform = True
                break

        self.text_column_name = text_column_name
        self.names = names
        self.max_text_len = max_text_len
        self.draw_false_image = draw_false_image
        self.draw_false_text = draw_false_text
        self.image_only = image_only
        self.data_dir = data_dir
        # get_text/get_false_text read self.tokenizer, so the argument must be
        # stored on the instance.
        self.tokenizer = tokenizer

        if len(names) != 0:
            # Skip names whose .arrow file is missing so that the surviving
            # names stay aligned with the tables loaded from them.
            names = [
                name for name in names if os.path.isfile(f"{data_dir}/{name}.arrow")
            ]
            tables = [
                pa.ipc.RecordBatchFileReader(
                    pa.memory_map(f"{data_dir}/{name}.arrow", "r")
                ).read_all()
                for name in names
            ]
            self.table_names = list()
            for i, name in enumerate(names):
                self.table_names += [name] * len(tables[i])
            self.table = pa.concat_tables(tables, promote=True)

            if text_column_name != "":
                self.text_column_name = text_column_name
                self.all_texts = self.table[text_column_name].to_pandas().tolist()
                if isinstance(self.all_texts[0][0], str):
                    self.all_texts = (
                        [list(set(texts)) for texts in self.all_texts]
                        if remove_duplicate
                        else self.all_texts
                    )
                else:  # snli: each element is a pair; keep its second field as the text
                    self.all_texts = [
                        [t[1].strip() for t in texts] for texts in self.all_texts
                    ]
            else:
                self.all_texts = list()
        else:
            self.all_texts = list()

        # Map a flat sample index to (table row, caption index within that row);
        # the caption index is None when only images are used.
        self.index_mapper = dict()
        if text_column_name != "" and not self.image_only:
            j = 0
            for i, texts in enumerate(self.all_texts):
                for _j in range(len(texts)):
                    self.index_mapper[j] = (i, _j)
                    j += 1
        else:
            for i in range(len(self.table)):
                self.index_mapper[i] = (i, None)

    @property
    def corpus(self):
        return [text for texts in self.all_texts for text in texts]

    def __len__(self):
        return len(self.index_mapper)

    def get_raw_image(self, index, image_key="image"):
        index, caption_index = self.index_mapper[index]
        image_bytes = io.BytesIO(self.table[image_key][index].as_py())
        image_bytes.seek(0)
        if self.clip_transform:
            return Image.open(image_bytes).convert("RGBA")
        else:
            return Image.open(image_bytes).convert("RGB")

    def get_image(self, index, image_key="image"):
        image = self.get_raw_image(index, image_key=image_key)
        image_tensor = [tr(image) for tr in self.transforms]
        return {
            "image": image_tensor,
            "img_index": self.index_mapper[index][0],
            "cap_index": self.index_mapper[index][1],
            "raw_index": index,
        }

    def get_false_image(self, rep, image_key="image"):
        random_index = random.randint(0, len(self.index_mapper) - 1)
        image = self.get_raw_image(random_index, image_key=image_key)
        image_tensor = [tr(image) for tr in self.transforms]
        return {f"false_image_{rep}": image_tensor}

    def get_text(self, raw_index):
        index, caption_index = self.index_mapper[raw_index]
        text = self.all_texts[index][caption_index]
        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_special_tokens_mask=True,
        )
        return {
            "text": (text, encoding),
            "img_index": index,
            "cap_index": caption_index,
            "raw_index": raw_index,
        }

    def get_false_text(self, rep):
        random_index = random.randint(0, len(self.index_mapper) - 1)
        index, caption_index = self.index_mapper[random_index]
        text = self.all_texts[index][caption_index]
        # Pad like get_text so every encoding shares the same fixed length.
        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_special_tokens_mask=True,
        )
        return {f"false_text_{rep}": (text, encoding)}

    def get_suite(self, index):
        result = None
        while result is None:
            try:
                ret = dict()
                ret.update(self.get_image(index))
                if not self.image_only:
                    txt = self.get_text(index)
                    # cap_index > 0 means a non-first caption of the same image.
                    ret.update({"replica": txt["cap_index"] > 0})
                    ret.update(txt)
                for i in range(self.draw_false_image):
                    ret.update(self.get_false_image(i))
                for i in range(self.draw_false_text):
                    ret.update(self.get_false_text(i))
                result = True
            except Exception as e:
                # On a corrupt sample, log and retry with a random index instead
                # of failing the whole epoch.
                print(f"Error while reading file idx {index} in {self.names[0]} -> {e}")
                index = random.randint(0, len(self.index_mapper) - 1)
        return ret
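
    # For illustration, a suite returned above bundles (values hypothetical):
    #   {"image": [view_tensor, ...], "text": (caption, encoding), "replica": False,
    #    "img_index": 7, "cap_index": 0, "raw_index": 42,
    #    "false_image_0": [...], "false_text_0": (...), ...}
    # collate below turns a list of such dicts into padded batch tensors.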

    def collate(self, batch, mlm_collator):
        batch_size = len(batch)
        keys = set([key for b in batch for key in b.keys()])
        dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}

        img_keys = [k for k in list(dict_batch.keys()) if "image" in k]
        img_sizes = list()

        for img_key in img_keys:
            img = dict_batch[img_key]
            img_sizes += [ii.shape for i in img if i is not None for ii in i]

        for size in img_sizes:
            assert (
                len(size) == 3
            ), f"collate error: expected image shape (3, H, W), got {size}"

        if len(img_keys) != 0:
            max_height = max([i[1] for i in img_sizes])
            max_width = max([i[2] for i in img_sizes])

        for img_key in img_keys:
            img = dict_batch[img_key]
            view_size = len(img[0])

            new_images = [
                torch.zeros(batch_size, 3, max_height, max_width)
                for _ in range(view_size)
            ]

            for bi in range(batch_size):
                orig_batch = img[bi]
                for vi in range(view_size):
                    if orig_batch is None:
                        # A missing image stays all-zero padding; assigning None
                        # into a tensor slice would raise a TypeError.
                        continue
                    orig = img[bi][vi]
                    new_images[vi][bi, :, : orig.shape[1], : orig.shape[2]] = orig

            dict_batch[img_key] = new_images
        txt_keys = [k for k in list(dict_batch.keys()) if "text" in k]

        if len(txt_keys) != 0:
            encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys]
            # Run the MLM collator once over the encodings of every text field,
            # then slice its output back out per field below.
            flatten_encodings = [e for encoding in encodings for e in encoding]
            flatten_mlms = mlm_collator(flatten_encodings)

            for i, txt_key in enumerate(txt_keys):
                texts, encodings = (
                    [d[0] for d in dict_batch[txt_key]],
                    [d[1] for d in dict_batch[txt_key]],
                )

                mlm_ids, mlm_labels = (
                    flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)],
                    flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)],
                )

                input_ids = torch.zeros_like(mlm_ids)
                attention_mask = torch.zeros_like(mlm_ids)
                for _i, encoding in enumerate(encodings):
                    _input_ids, _attention_mask = (
                        torch.tensor(encoding["input_ids"]),
                        torch.tensor(encoding["attention_mask"]),
                    )
                    input_ids[_i, : len(_input_ids)] = _input_ids
                    attention_mask[_i, : len(_attention_mask)] = _attention_mask

                dict_batch[txt_key] = texts
                dict_batch[f"{txt_key}_ids"] = input_ids
                # -100 labels are ignored by cross-entropy; the non-MLM ids carry
                # no language-modeling targets.
                dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100)
                dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids
                dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels
                dict_batch[f"{txt_key}_masks"] = attention_mask

        return dict_batch
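

# A minimal, self-contained sketch of how `collate` is meant to be driven
# (illustration only, not part of the original module). It fabricates two
# batch items by hand and uses Hugging Face's DataCollatorForLanguageModeling
# as the `mlm_collator`; the tokenizer checkpoint name is an assumption for
# the example. `collate` reads no instance state, so it is exercised unbound.
if __name__ == "__main__":
    from transformers import BertTokenizerFast, DataCollatorForLanguageModeling

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    mlm_collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)

    def fake_item(h, w, caption):
        # Mirror what get_image/get_text would return for one sample.
        encoding = tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=40,
            return_special_tokens_mask=True,
        )
        return {"image": [torch.zeros(3, h, w)], "text": (caption, encoding)}

    batch = [fake_item(224, 224, "a dog"), fake_item(200, 256, "a cat")]
    out = BaseDataset.collate(None, batch, mlm_collator)
    print(out["image"][0].shape)  # torch.Size([2, 3, 224, 256]): zero-padded views
    print(out["text_ids"].shape)  # torch.Size([2, 40])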