Skip to content

Commit

Permalink
[feature] Hateful Memes Dataset (facebookresearch#115)
Browse files Browse the repository at this point in the history
Summary:
- Dataset zoo for HM
- All of the model configurations
- Image processors have been added
- Properly use MMFDataset for the HM dataset
- Some bug fixes in the models
Pull Request resolved: https://github.com/fairinternal/mmf-internal/pull/115

Test Plan:
All of the models present in the paper are tested to be working

As with the previous commit, you can set the data dir to my data dir and
test with it.

Reviewed By: vedanuj

Differential Revision: D21447066

Pulled By: apsdehal

fbshipit-source-id: 8890503e95075ebe33eac02a3be540ff980c6b6b
  • Loading branch information
apsdehal committed May 8, 2020
1 parent b28b6e5 commit ef04cc2
Show file tree
Hide file tree
Showing 30 changed files with 523 additions and 51 deletions.
12 changes: 12 additions & 0 deletions mmf/configs/datasets/hateful_memes/bert.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Overrides the Hateful Memes text processor to use the BERT tokenizer.
dataset_config:
  hateful_memes:
    processors:
      text_processor:
        type: bert_tokenizer
        params:
          tokenizer_config:
            type: bert-base-uncased
            params:
              do_lower_case: true
          # No token masking is applied.
          mask_probability: 0
          max_seq_length: 128
62 changes: 62 additions & 0 deletions mmf/configs/datasets/hateful_memes/defaults.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Default configuration for the Hateful Memes dataset (raw images enabled,
# pre-extracted features disabled).
dataset_config:
  hateful_memes:
    data_dir: ${env.data_dir}/datasets
    depth_first: false
    fast_read: false
    use_images: true
    use_features: false
    # Image folders per split; all three splits point at the same folder.
    images:
      train:
        - hateful_memes/defaults/images/
      val:
        - hateful_memes/defaults/images/
      test:
        - hateful_memes/defaults/images/
    # Feature databases per split; all three splits point at the same LMDB.
    features:
      train:
        - hateful_memes/defaults/features/detectron.lmdb
      val:
        - hateful_memes/defaults/features/detectron.lmdb
      test:
        - hateful_memes/defaults/features/detectron.lmdb
    annotations:
      train:
        - hateful_memes/defaults/annotations/train.jsonl
      val:
        - hateful_memes/defaults/annotations/dev.jsonl
      test:
        - hateful_memes/defaults/annotations/test.jsonl
    max_features: 100
    processors:
      text_processor:
        type: vocab
        params:
          max_length: 14
          vocab:
            type: intersected
            embedding_name: glove.6B.300d
            vocab_file: hateful_memes/defaults/extras/vocabs/vocabulary_100k.txt
          preprocessor:
            type: simple_sentence
            params: {}
      bbox_processor:
        type: bbox
        params:
          max_length: 50
      image_processor:
        type: torchvision_transforms
        params:
          # Transforms are listed in the order they appear in the pipeline.
          transforms:
            - type: Resize
              params:
                size: [256, 256]
            - type: CenterCrop
              params:
                size: [224, 224]
            - ToTensor
            - GrayScaleTo3Channels
            - type: Normalize
              params:
                mean: [0.46777044, 0.44531429, 0.40661017]
                std: [0.12221994, 0.12145835, 0.14380469]
    return_features_info: false
7 changes: 7 additions & 0 deletions mmf/configs/datasets/hateful_memes/with_features.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Switches the Hateful Memes dataset from raw images to pre-extracted features.
dataset_config:
  hateful_memes:
    use_images: false
    use_features: true
    # Disable this in your config if you do not need features info
    # and are running out of memory
    return_features_info: true
5 changes: 4 additions & 1 deletion mmf/configs/models/cnn_lstm/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ model_config:
output_dims: [64, 128, 128, 64, 64, 10]
kernel_sizes: [7, 5, 5, 5, 5, 1]
classifier:
input_dim: 450
type: mlp
params:
in_dim: 450
out_dim: 2
22 changes: 22 additions & 0 deletions mmf/configs/zoo/datasets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,25 @@ coco:
- url: mmf://datasets/coco/ocr_en/features/features.tar.gz
file_name: features.tar.gz
hashcode: 8d4d67e878208568934c2c3fb1c304f5073b5a89a25a59938d182e360e23473f


# Downloadable resources for the Hateful Memes dataset, grouped by kind.
hateful_memes:
  defaults:
    version: 1.0_2020_05_04
    resources:
      features:
        - url: mmf://datasets/hateful_memes/defaults/features/features.tar.gz
          file_name: features.tar.gz
          hashcode: 1eb8e5379fcf8f91fda92aa8f5926a536f3788bf07fe0f72ea7efc2d8427f12d
      images:
        - url: mmf://datasets/hateful_memes/defaults/images/images.tar.gz
          file_name: images.tar.gz
          hashcode: 6db0c78bdc16bec6a4381d1b2d2a9ac4ac0643d4a329a4562d16c85cfe4b43be
      annotations:
        - url: mmf://datasets/hateful_memes/defaults/annotations/annotations.tar.gz
          file_name: annotations.tar.gz
          hashcode: 452486b03083b0912874215a58b3df8227bafb8635904faae4e4ae402baaf13f
      extras:
        - url: mmf://datasets/hateful_memes/defaults/extras.tar.gz
          file_name: extras.tar.gz
          hashcode: 1bd88fa36b5c565234cd0bbc20189c85b51a283337bee574db91521be0364739
47 changes: 47 additions & 0 deletions mmf/datasets/builders/hateful_memes/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import warnings

from mmf.common.registry import registry
from mmf.datasets.builders.hateful_memes.dataset import (
HatefulMemesFeaturesDataset,
HatefulMemesImageDataset,
)
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder


@registry.register_builder("hateful_memes")
class HatefulMemesBuilder(MMFDatasetBuilder):
    """Dataset builder registered under the key ``"hateful_memes"``.

    Builds ``HatefulMemesImageDataset`` by default and switches to
    ``HatefulMemesFeaturesDataset`` when the config has ``use_features`` set.
    """

    def __init__(
        self,
        dataset_name="hateful_memes",
        dataset_class=HatefulMemesImageDataset,
        *args,
        **kwargs
    ):
        super().__init__(dataset_name, dataset_class, *args, **kwargs)
        # NOTE(review): this unconditionally resets the class to the image
        # dataset, so a caller-supplied ``dataset_class`` is overridden here.
        # Kept as-is to preserve behavior -- confirm that this is intended.
        self.dataset_class = HatefulMemesImageDataset

    @classmethod
    def config_path(cls):
        """Return the path string of this dataset's default config file."""
        return "configs/datasets/hateful_memes/defaults.yaml"

    def load(self, config, dataset_type, *args, **kwargs):
        """Load the dataset for ``dataset_type`` and keep it on ``self.dataset``.

        Picks the features-based dataset class when ``config.use_features``
        is truthy, then delegates the actual loading to the parent builder.

        Returns:
            Whatever the parent ``load`` returns (also stored as
            ``self.dataset``).
        """
        if config.use_features:
            self.dataset_class = HatefulMemesFeaturesDataset

        self.dataset = super().load(config, dataset_type, *args, **kwargs)

        return self.dataset

    def update_registry_for_model(self, config):
        """Register dataset-derived values in the global registry.

        Registers ``<dataset_name>_text_vocab_size`` when the loaded dataset
        exposes a ``text_processor``, and always registers
        ``<dataset_name>_num_final_outputs`` as 2.
        """
        if hasattr(self.dataset, "text_processor"):
            registry.register(
                self.dataset_name + "_text_vocab_size",
                self.dataset.text_processor.get_vocab_size(),
            )
        registry.register(
            self.dataset_name + "_num_final_outputs", 2,
        )
76 changes: 76 additions & 0 deletions mmf/datasets/builders/hateful_memes/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import os

import omegaconf
import torch
from PIL import Image
from torchvision import transforms

from mmf.common.sample import Sample
from mmf.datasets.mmf_dataset import MMFDataset
from mmf.utils.general import get_mmf_root


class HatefulMemesFeaturesDataset(MMFDataset):
    """Hateful Memes dataset that returns pre-extracted features per sample.

    Each item carries the processed text, the numeric sample id, whatever
    ``self.features_db.get`` returns for the sample, and the label as
    ``targets``.
    """

    def __init__(self, config, *args, dataset_name="hateful_memes", **kwargs):
        super().__init__(dataset_name, config, *args, **kwargs)
        # The message previously referred to 'use_images' / "image dataset",
        # which pointed users at the wrong config flag when this check failed.
        assert (
            self._use_features
        ), "config's 'use_features' must be true to use features dataset"

    def preprocess_sample_info(self, sample_info):
        """Add the ``feature_path`` key derived from the sample's ``id``.

        The passed ``sample_info`` is modified in place and returned.
        """
        image_id = sample_info["id"]
        # Add feature_path key for feature_database access
        sample_info["feature_path"] = f"{image_id}.npy"
        return sample_info

    def __getitem__(self, idx):
        """Build and return the ``Sample`` for annotation index ``idx``."""
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)

        current_sample = Sample()

        processed_text = self.text_processor({"text": sample_info["text"]})
        current_sample.text = processed_text["text"]
        # When the processor output contains "input_ids", copy all of its
        # keys onto the sample.
        if "input_ids" in processed_text:
            current_sample.update(processed_text)

        current_sample.id = torch.tensor(int(sample_info["id"]), dtype=torch.int)

        # Instead of using idx directly here, use sample_info to fetch
        # the features as feature_path has been dynamically added
        features = self.features_db.get(sample_info)
        current_sample.update(features)

        current_sample.targets = torch.tensor(sample_info["label"], dtype=torch.long)
        return current_sample


class HatefulMemesImageDataset(MMFDataset):
    """Hateful Memes dataset that returns the raw (transformed) image per sample.

    Each item carries the processed text, the numeric sample id, the first
    image returned by ``self.image_db`` for that index, and the label as
    ``targets``.
    """

    def __init__(self, config, *args, dataset_name="hateful_memes", **kwargs):
        super().__init__(dataset_name, config, *args, **kwargs)
        images_enabled = self._use_images
        assert images_enabled, (
            "config's 'use_images' must be true to use image dataset"
        )

    def init_processors(self):
        """Initialize processors, then use the image processor as the DB transform."""
        super().init_processors()
        self.image_db.transform = self.image_processor

    def __getitem__(self, idx):
        """Build and return the ``Sample`` for annotation index ``idx``."""
        annotation = self.annotation_db[idx]
        sample = Sample()

        text_output = self.text_processor({"text": annotation["text"]})
        sample.text = text_output["text"]
        # When the processor output contains "input_ids", copy all of its
        # keys onto the sample.
        if "input_ids" in text_output:
            sample.update(text_output)

        numeric_id = int(annotation["id"])
        sample.id = torch.tensor(numeric_id, dtype=torch.int)

        # Only the first of the images returned by the image_db is used.
        images = self.image_db[idx]["images"]
        sample.image = images[0]

        label = annotation["label"]
        sample.targets = torch.tensor(label, dtype=torch.long)
        return sample
28 changes: 18 additions & 10 deletions mmf/datasets/databases/image_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,23 @@ def __init__(
self.image_key = image_key if image_key else self.image_key
self.is_valid_file = is_valid_file

if self.image_key:
assert (
self.annotation_db is not None
), "Annotation DB should be specified with image key"
@property
def annotation_db(self):
return self._annotation_db

def set_annotation_db(self, annotation_db):
self.annotation_db = annotation_db
@annotation_db.setter
def annotation_db(self, annotation_db):
self._annotation_db = annotation_db

@property
def transform(self):
return self._transform

def set_transforms(self, transform):
@transform.setter
def transform(self, transform):
if isinstance(transform, collections.abc.MutableSequence):
transform = torchvision.Compose(transform)
self.transform = transform
self._transform = transform

def __len__(self):
self._check_annotation_db_present()
Expand All @@ -102,7 +107,7 @@ def _check_annotation_db_present(self):
if not self.annotation_db:
raise AttributeError(
"'annotation_db' must be set for the database to use __getitem__."
+ " Use set_annotation_db."
+ " Use image_database.annotation_db to set it."
)

def get(self, item):
Expand All @@ -128,6 +133,8 @@ def from_path(self, paths):
continue

if not path:
# Create the full path without extension so it can be printed
# for the error
possible_path = os.path.join(
self.base_path, ".".join(image.split(".")[:-1])
)
Expand All @@ -137,8 +144,9 @@ def from_path(self, paths):
possible_path
)
)

path = os.path.join(self.base_path, path)
image = self.open_image(path)

if self.transform:
image = self.transform(image)
loaded_images.append(image)
Expand Down
2 changes: 2 additions & 0 deletions mmf/datasets/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from mmf.datasets.processors.bert_processors import MaskedTokenProcessor
from mmf.datasets.processors.image_processors import TorchvisionTransforms
from mmf.datasets.processors.processors import (
BaseProcessor,
BBoxProcessor,
Expand Down Expand Up @@ -28,4 +29,5 @@
"BBoxProcessor",
"CaptionProcessor",
"MaskedTokenProcessor",
"TorchvisionTransforms",
]
Loading

0 comments on commit ef04cc2

Please sign in to comment.