forked from facebookresearch/mmf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature] Hateful Memes Dataset (facebookresearch#115)
Summary: - Dataset zoo for HM - All of model configurations - Image processors have been added - Properly use MMFDataset for the HM dataset - Some bug fixes in the models Pull Request resolved: https://github.com/fairinternal/mmf-internal/pull/115 Test Plan: All of the models present in the paper are tested to be working Same as the previous commit you can set the data dir to my data dir and test with it. Reviewed By: vedanuj Differential Revision: D21447066 Pulled By: apsdehal fbshipit-source-id: 8890503e95075ebe33eac02a3be540ff980c6b6b
- Loading branch information
Showing
30 changed files
with
523 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
dataset_config: | ||
hateful_memes: | ||
processors: | ||
text_processor: | ||
type: bert_tokenizer | ||
params: | ||
tokenizer_config: | ||
type: bert-base-uncased | ||
params: | ||
do_lower_case: true | ||
mask_probability: 0 | ||
max_seq_length: 128 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
dataset_config: | ||
hateful_memes: | ||
data_dir: ${env.data_dir}/datasets | ||
depth_first: false | ||
fast_read: false | ||
use_images: true | ||
use_features: false | ||
images: | ||
train: | ||
- hateful_memes/defaults/images/ | ||
val: | ||
- hateful_memes/defaults/images/ | ||
test: | ||
- hateful_memes/defaults/images/ | ||
features: | ||
train: | ||
- hateful_memes/defaults/features/detectron.lmdb | ||
val: | ||
- hateful_memes/defaults/features/detectron.lmdb | ||
test: | ||
- hateful_memes/defaults/features/detectron.lmdb | ||
annotations: | ||
train: | ||
- hateful_memes/defaults/annotations/train.jsonl | ||
val: | ||
- hateful_memes/defaults/annotations/dev.jsonl | ||
test: | ||
- hateful_memes/defaults/annotations/test.jsonl | ||
max_features: 100 | ||
processors: | ||
text_processor: | ||
type: vocab | ||
params: | ||
max_length: 14 | ||
vocab: | ||
type: intersected | ||
embedding_name: glove.6B.300d | ||
vocab_file: hateful_memes/defaults/extras/vocabs/vocabulary_100k.txt | ||
preprocessor: | ||
type: simple_sentence | ||
params: {} | ||
bbox_processor: | ||
type: bbox | ||
params: | ||
max_length: 50 | ||
image_processor: | ||
type: torchvision_transforms | ||
params: | ||
transforms: | ||
- type: Resize | ||
params: | ||
size: [256, 256] | ||
- type: CenterCrop | ||
params: | ||
size: [224, 224] | ||
- ToTensor | ||
- GrayScaleTo3Channels | ||
- type: Normalize | ||
params: | ||
mean: [0.46777044, 0.44531429, 0.40661017] | ||
std: [0.12221994, 0.12145835, 0.14380469] | ||
return_features_info: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
dataset_config: | ||
hateful_memes: | ||
use_images: false | ||
use_features: true | ||
# Set this to false in your config if you do not need features info | ||
# and are running out of memory | ||
return_features_info: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
import warnings | ||
|
||
from mmf.common.registry import registry | ||
from mmf.datasets.builders.hateful_memes.dataset import ( | ||
HatefulMemesFeaturesDataset, | ||
HatefulMemesImageDataset, | ||
) | ||
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder | ||
|
||
|
||
@registry.register_builder("hateful_memes")
class HatefulMemesBuilder(MMFDatasetBuilder):
    """Dataset builder for Hateful Memes, registered as ``hateful_memes``.

    The builder starts out with the dataset class it is given (the image
    dataset by default) and switches to ``HatefulMemesFeaturesDataset`` in
    ``load`` when the config sets ``use_features``.
    """

    def __init__(
        self,
        dataset_name="hateful_memes",
        dataset_class=HatefulMemesImageDataset,
        *args,
        **kwargs
    ):
        """Initialize the builder.

        Args:
            dataset_name: Name forwarded to the parent builder.
            dataset_class: Dataset class to build; defaults to
                ``HatefulMemesImageDataset``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(dataset_name, dataset_class, *args, **kwargs)
        # Keep the class the caller asked for. Previously this was always
        # reset to HatefulMemesImageDataset, silently discarding the
        # ``dataset_class`` argument. The default behavior is unchanged.
        self.dataset_class = dataset_class

    @classmethod
    def config_path(cls):
        """Return the path of the default config for this dataset."""
        return "configs/datasets/hateful_memes/defaults.yaml"

    def load(self, config, dataset_type, *args, **kwargs):
        """Load the dataset through the parent builder and return it.

        Args:
            config: Dataset config; its ``use_features`` flag selects the
                features-based dataset class.
            dataset_type: Forwarded unchanged to the parent ``load``
                (presumably the split name -- confirm against the parent).
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            The dataset returned by the parent ``load``; it is also stored
            on ``self.dataset``.
        """
        if config.use_features:
            self.dataset_class = HatefulMemesFeaturesDataset

        self.dataset = super().load(config, dataset_type, *args, **kwargs)
        return self.dataset

    def update_registry_for_model(self, config):
        """Register dataset-derived values in the registry.

        Registers ``<dataset_name>_text_vocab_size`` when the loaded dataset
        has a ``text_processor``, and always registers
        ``<dataset_name>_num_final_outputs`` as 2.

        Args:
            config: Unused here; kept for interface compatibility.
        """
        if hasattr(self.dataset, "text_processor"):
            registry.register(
                self.dataset_name + "_text_vocab_size",
                self.dataset.text_processor.get_vocab_size(),
            )
        # NOTE(review): registered outside the text_processor check; the
        # scraped source lost its indentation, so confirm against the
        # original file.
        registry.register(self.dataset_name + "_num_final_outputs", 2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
import copy | ||
import os | ||
|
||
import omegaconf | ||
import torch | ||
from PIL import Image | ||
from torchvision import transforms | ||
|
||
from mmf.common.sample import Sample | ||
from mmf.datasets.mmf_dataset import MMFDataset | ||
from mmf.utils.general import get_mmf_root | ||
|
||
|
||
class HatefulMemesFeaturesDataset(MMFDataset):
    """Hateful Memes dataset that serves entries from the features database.

    Each item combines the processed text, the numeric sample id, whatever
    ``self.features_db.get`` returns for the sample, and the label.
    """

    def __init__(self, config, *args, dataset_name="hateful_memes", **kwargs):
        """Initialize the dataset and check that features are enabled.

        Args:
            config: Dataset config forwarded to ``MMFDataset``.
            *args: Extra positional arguments forwarded to the parent.
            dataset_name: Name forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Raises:
            AssertionError: If ``self._use_features`` is falsy after the
                parent initializer has run.
        """
        super().__init__(dataset_name, config, *args, **kwargs)
        # The message used to mention 'use_images' / "image dataset", which
        # pointed users at the wrong config flag for this class.
        assert (
            self._use_features
        ), "config's 'use_features' must be true to use features dataset"

    def preprocess_sample_info(self, sample_info):
        """Add the ``feature_path`` key derived from the sample's ``id``.

        Args:
            sample_info: Annotation entry; must contain ``"id"``. It is
                modified in place.

        Returns:
            The same ``sample_info`` object with ``"feature_path"`` set to
            ``"<id>.npy"``.
        """
        image_id = sample_info["id"]
        # Add feature_path key for feature_database access
        sample_info["feature_path"] = f"{image_id}.npy"
        return sample_info

    def __getitem__(self, idx):
        """Build the ``Sample`` for the annotation at ``idx``.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            A ``Sample`` with ``text``, ``id``, the entries returned by the
            features database, and ``targets``. When the text processor
            output contains ``"input_ids"``, all of its keys are merged in.
        """
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)

        current_sample = Sample()

        processed_text = self.text_processor({"text": sample_info["text"]})
        current_sample.text = processed_text["text"]
        # Processors that emit "input_ids" also produce extra keys the sample
        # needs, so merge the whole output in that case.
        if "input_ids" in processed_text:
            current_sample.update(processed_text)

        current_sample.id = torch.tensor(int(sample_info["id"]), dtype=torch.int)

        # Instead of using idx directly here, use sample_info to fetch
        # the features as feature_path has been dynamically added
        features = self.features_db.get(sample_info)
        current_sample.update(features)

        current_sample.targets = torch.tensor(sample_info["label"], dtype=torch.long)
        return current_sample
|
||
|
||
class HatefulMemesImageDataset(MMFDataset):
    """Hateful Memes dataset that serves images from the image database.

    Each item combines the processed text, the numeric sample id, the first
    image returned by ``self.image_db`` for the index, and the label.
    """

    def __init__(self, config, *args, dataset_name="hateful_memes", **kwargs):
        """Initialize the dataset and check that images are enabled.

        Raises:
            AssertionError: If ``self._use_images`` is falsy after the
                parent initializer has run.
        """
        super().__init__(dataset_name, config, *args, **kwargs)
        images_enabled = self._use_images
        assert images_enabled, "config's 'use_images' must be true to use image dataset"

    def init_processors(self):
        """Run the parent setup, then attach the image processor to image_db."""
        super().init_processors()
        # The image database applies this transform itself.
        self.image_db.transform = self.image_processor

    def __getitem__(self, idx):
        """Build the ``Sample`` for the annotation at ``idx``.

        Args:
            idx: Index into ``self.annotation_db`` and ``self.image_db``.

        Returns:
            A ``Sample`` with ``text``, ``id``, ``image`` and ``targets``.
            When the text processor output contains ``"input_ids"``, all of
            its keys are merged in.
        """
        annotation = self.annotation_db[idx]
        sample = Sample()

        text_output = self.text_processor({"text": annotation["text"]})
        sample.text = text_output["text"]
        if "input_ids" in text_output:
            sample.update(text_output)

        numeric_id = int(annotation["id"])
        sample.id = torch.tensor(numeric_id, dtype=torch.int)

        # image_db yields a collection of images per entry; only the first
        # one is used.
        image_entry = self.image_db[idx]
        sample.image = image_entry["images"][0]

        sample.targets = torch.tensor(annotation["label"], dtype=torch.long)
        return sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.