[feature] Add Late fusion model (facebookresearch#77)
Summary: Pull Request resolved: https://github.com/fairinternal/pythia-internal/pull/77

Reviewed By: apsdehal

Differential Revision: D21107711

Pulled By: vedanuj

fbshipit-source-id: bc1e7bda05ab6bdee8aaaafe95d878116d978d20
vedanuj authored and apsdehal committed May 8, 2020
1 parent 47c62b6 commit 55e9b71
Showing 6 changed files with 137 additions and 9 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
52 changes: 52 additions & 0 deletions mmf/configs/models/fusions/late_fusion.yaml
@@ -0,0 +1,52 @@
# Mostly the same as the configs/models/mmbt/defaults.yaml config
model_config:
  late_fusion:
    # Either pretraining or classification
    bert_model_name: bert-base-uncased
    direct_features_input: false
    freeze_text: false
    freeze_modal: false
    freeze_complete_base: false
    finetune_lr_multiplier: 1
    # Dimension of the embedding finally returned by the modal encoder
    modal_hidden_size: 2048
    # Dimension of the embedding finally returned by the text encoder
    text_hidden_size: 768
    # Used when classification head is activated
    num_labels: 2
    # Number of features extracted out per image
    num_features: 100

    modal_encoder:
      type: resnet152
      params:
        pretrained: true
        pool_type: avg
        num_output_features: 1

    text_encoder:
      type: transformer
      params:
        bert_model_name: ${model_config.late_fusion.bert_model_name}
        hidden_size: 768
        num_hidden_layers: 12
        num_attention_heads: 12
        output_attentions: false
        output_hidden_states: false

    modal_classifier:
      type: mlp
      params:
        in_dim: 2048
        out_dim: 2
        hidden_dim: 768
        num_layers: 0

    text_classifier:
      type: mlp
      params:
        in_dim: 768
        out_dim: 2
        hidden_dim: 768
        num_layers: 0
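
A note on the ${model_config.late_fusion.bert_model_name} reference above: MMF's configuration system is built on OmegaConf, so that interpolation is resolved against the root config at load time. A minimal standalone sketch of the same mechanism (plain OmegaConf, not MMF's config loader):

from omegaconf import OmegaConf

# Illustration only: the same absolute-path interpolation used in the yaml above.
cfg = OmegaConf.create(
    """
    model_config:
      late_fusion:
        bert_model_name: bert-base-uncased
        text_encoder:
          params:
            bert_model_name: ${model_config.late_fusion.bert_model_name}
    """
)
print(cfg.model_config.late_fusion.text_encoder.params.bert_model_name)
# bert-base-uncased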
65 changes: 56 additions & 9 deletions mmf/models/concat.py → mmf/models/fusions.py
@@ -1,3 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
 from copy import deepcopy

 import torch
@@ -9,7 +10,7 @@
 from mmf.utils.modeling import get_bert_configured_parameters


-class ConcatBase(MultiModalEncoderBase):
+class FusionBase(MultiModalEncoderBase):
     def __init__(self, config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)

@@ -49,8 +50,7 @@ def forward(
         modal = self.modal(modal, *modal_args, **modal_kwargs)
         modal = torch.flatten(modal, start_dim=1)
         text = torch.flatten(text, start_dim=1)
-        out = torch.cat([text, modal], dim=-1)
-        return out
+        return text, modal


 @registry.register_model("concat_bert")
@@ -61,10 +61,10 @@ def __init__(self, config, *args, **kwargs):

     @classmethod
     def config_path(cls):
-        return "configs/models/concat/concat_bert.yaml"
+        return "configs/models/fusions/concat_bert.yaml"

     def build(self):
-        self.base = ConcatBase(self.config)
+        self.base = FusionBase(self.config)
         num_features = self.config.num_features
         if not self._is_direct_features_input:
             num_features = self.config.modal_encoder.params.num_output_features
@@ -105,7 +105,8 @@ def forward(self, sample_list):
         else:
             modal = sample_list.image

-        embedding = self.base(text, modal, [mask, segment])
+        text_embedding, modal_embedding = self.base(text, modal, [mask, segment])
+        embedding = torch.cat([text_embedding, modal_embedding], dim=-1)
         output = {}
         output["scores"] = self.classifier(embedding)
         return output
@@ -119,10 +120,10 @@ def __init__(self, config, *args, **kwargs):

     @classmethod
     def config_path(cls):
-        return "configs/models/concat/concat_bow.yaml"
+        return "configs/models/fusions/concat_bow.yaml"

     def build(self):
-        self.base = ConcatBase(self.config)
+        self.base = FusionBase(self.config)
         num_features = self.config.num_features
         if not self._is_direct_features_input:
             num_features = self.config.modal_encoder.params.num_output_features
Expand All @@ -140,7 +141,53 @@ def forward(self, sample_list):
else:
modal = sample_list.image

embedding = self.base(text, modal)
text_embedding, modal_embedding = self.base(text, modal)
embedding = torch.cat([text_embedding, modal_embedding], dim=-1)
output = {}
output["scores"] = self.classifier(embedding)
return output


@registry.register_model("late_fusion")
class LateFusion(BaseModel):
def __init__(self, config, *args, **kwargs):
super().__init__(config)
self._is_direct_features_input = config.direct_features_input

@classmethod
def config_path(cls):
return "configs/models/fusions/late_fusion.yaml"

def build(self):
self.base = FusionBase(self.config)
num_features = self.config.num_features
if not self._is_direct_features_input:
num_features = self.config.modal_encoder.params.num_output_features

# As the in_dim is dynamically calculated we need to copy classifier_config
modal_classifier_config = deepcopy(self.config.modal_classifier)
modal_classifier_config.params.in_dim = (
num_features * self.config.modal_hidden_size
)
self.modal_classifier = build_classifier_layer(modal_classifier_config)

text_classifier_config = deepcopy(self.config.text_classifier)
text_classifier_config.params.in_dim = self.config.text_hidden_size
self.text_classifier = build_classifier_layer(text_classifier_config)

def forward(self, sample_list):
text = sample_list.input_ids
mask = sample_list.input_mask
segment = sample_list.segment_ids

if self._is_direct_features_input:
modal = sample_list.image_feature_0
else:
modal = sample_list.image

text_embedding, modal_embedding = self.base(text, modal, [mask, segment])
text = self.text_classifier(text_embedding)
modal = self.modal_classifier(modal_embedding)
output = {}
output["scores"] = (text + modal) / 2
return output
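
The takeaway from this file: concat_bert and concat_bow fuse by concatenating the two embeddings and classifying once, while late_fusion classifies each modality separately and averages the logits. A minimal sketch of the two strategies in plain PyTorch (an illustration using the default config's dimensions, not MMF code):

import torch
from torch import nn

batch = 4
text_emb = torch.randn(batch, 768)    # pooled text embedding (text_hidden_size)
modal_emb = torch.randn(batch, 2048)  # pooled image embedding (modal_hidden_size)

# Concat fusion: a single classifier over the concatenated embedding.
concat_head = nn.Linear(768 + 2048, 2)
concat_scores = concat_head(torch.cat([text_emb, modal_emb], dim=-1))

# Late fusion: one head per modality, logits averaged.
text_head = nn.Linear(768, 2)
modal_head = nn.Linear(2048, 2)
late_scores = (text_head(text_emb) + modal_head(modal_emb)) / 2

print(concat_scores.shape, late_scores.shape)  # both torch.Size([4, 2])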
29 changes: 29 additions & 0 deletions projects/others/late_fusion/hateful_memes/defaults.yaml
@@ -0,0 +1,29 @@
includes:
- configs/datasets/hateful_memes/bert.yaml

model_config:
  late_fusion:
    metrics:
    - accuracy
    losses:
    - type: cross_entropy

scheduler:
  type: warmup_linear
  params:
    num_warmup_steps: 2000
    num_training_steps: ${training.max_updates}

optimizer:
  type: adam_w
  params:
    lr: 5e-5
    eps: 1e-8

training:
  batch_size: 128
  lr_scheduler: true
  max_updates: 22000
  monitored_metric: hateful_memes/accuracy
  pretrained_mapping:
    base: base