[ImageS2S] Early fusion images (facebookresearch#2797)
* early fusion image seq2seq

* black

* add tests

* black

* typing

* add some comments, address eric's concerns
klshuster authored Jun 30, 2020
1 parent aed650b commit f8eb2e9
Showing 3 changed files with 186 additions and 27 deletions.
40 changes: 30 additions & 10 deletions parlai/agents/image_seq2seq/image_seq2seq.py
@@ -11,7 +11,7 @@

import torch

from .modules import ImageSeq2seqModel
from .modules import ImageSeq2seqModel, FusionType
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.dict import DictionaryAgent
from parlai.core.torch_agent import Batch
@@ -29,7 +29,7 @@ class ImageSeq2seqAgent(TransformerGeneratorAgent, TorchImageAgent):
Combines a transformer generator with images.
"""

def build_model(self) -> ImageSeq2seqModel:
def build_model(self) -> ImageSeq2seqModel: # type: ignore
"""
Override to build appropriate model.
"""
@@ -55,6 +55,13 @@ def add_cmdline_args(cls, argparser):
recommended=True,
help='if true, include image token (or no image token) for each example',
)
group.add_argument(
'--image-fusion-type',
type=str,
default='late',
choices=[f.value for f in FusionType],
help='which fusion type to use',
)
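The new flag's choices come straight from the FusionType enum added in modules.py below; as a minimal sketch (not part of this commit) of how the string option round-trips through that enum:

import torch  # noqa: F401  (only to match the project's torch-based environment)
from parlai.agents.image_seq2seq.modules import FusionType

# The argparse choices are generated from the enum's values, and the encoder
# converts the chosen string back with FusionType(fusion) before dispatching.
print([f.value for f in FusionType])  # ['early', 'late']
fusion = FusionType('early')
print(fusion is FusionType.EARLY)     # True -- Enum members are singletons, so `is` checks work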

def build_dictionary(self) -> DictionaryAgent:
"""
@@ -77,7 +84,7 @@ def _set_text_vec(self, *args, **kwargs) -> dict:
if self.opt.get('include_image_token', False):
# `truncate` is the third arg to this function
truncate = args[2] - 1 if args[2] is not None else None
vec = torch.LongTensor(
vec = torch.LongTensor( # type: ignore
self._check_truncate(obs['text_vec'], truncate, True)
)
token = TOKEN_NO_IMAGE
@@ -94,11 +101,14 @@ def _dummy_batch(self, batchsize: int, maxlen: int) -> Batch:
Override to include image feats.
"""
b = super()._dummy_batch(batchsize, maxlen)
image = torch.ones(batchsize, self.image_features_dim).cuda()
if self.fp16:
image = image.half()
return Batch(
text_vec=b.text_vec,
label_vec=b.label_vec,
image=torch.ones(batchsize, self.image_features_dim).cuda(),
personalities=torch.ones(batchsize, self.opt.get('embedding_size')).cuda(),
image=image,
personalities=torch.ones(batchsize, self.opt['embedding_size']).cuda(),
)
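The fp16 branch above casts the dummy image features to half precision; a standalone sketch (toy sizes, assuming a CUDA device is available) of the dtype mismatch this avoids:

import torch

# A half-precision layer cannot consume float32 inputs directly, which is why
# the dummy image features are converted with .half() when fp16 is enabled.
layer = torch.nn.Linear(2048, 300).cuda().half()
feats = torch.ones(4, 2048).cuda()   # float32 by default
# layer(feats) would raise a RuntimeError about Half vs. Float dtypes
out = layer(feats.half())
print(out.dtype)                     # torch.float16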

def batchify_image_features(self, batch: Batch) -> Batch:
@@ -123,11 +133,12 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
"""
Override for custom loading.
Three reasons:
Reasons:
1. When using an init model without an image encoder
2. When using an init model with only an encoder provided
2. We decide to add segment embeddings after the fact.
3. When using an init model with only an encoder provided
In this case, we may need to add the START token to the state_dict
3. When using an init model without image tokens in the embeddings.
4. When using an init model without image tokens in the embeddings.
This is only the case if the embs differ by 2 in dimension 0
"""
state_dict['encoder.dummy_image_enc'] = self.model.encoder.dummy_image_enc
@@ -137,7 +148,16 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
for k, v in self.model.encoder.image_encoder.state_dict().items():
state_dict[f'encoder.image_encoder.{k}'] = v

# Case 2 -> Only an Encoder provided
# case 2 -> Segment embeddings in new model
if (
self.opt.get('n_segments', 0) >= 1
and 'encoder.segment_embeddings.weight' not in state_dict
):
state_dict[
'encoder.segment_embeddings.weight'
] = self.model.encoder.segment_embeddings.weight

# Case 3 -> Only an Encoder provided
if not (any('decoder' in state_key for state_key in state_dict)):
for k, v in self.model.decoder.state_dict().items():
state_dict[f'decoder.{k}'] = v
@@ -150,7 +170,7 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
self.model.load_state_dict(state_dict)
return
except RuntimeError as e:
# Case 3 --> Check for Embedding Diffs. Make sure dims match up
# Case 4 --> Check for Embedding Diffs. Make sure dims match up
embs = state_dict['embeddings.weight']
enc_embs = state_dict['encoder.embeddings.weight']
dec_embs = state_dict['decoder.embeddings.weight']
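The rest of this hunk is collapsed above. Purely as an illustration (the actual resizing logic is not shown in this diff, and the initialization strategy below is an assumption), a sketch of the kind of fixup Case 4 describes, where the checkpoint's embeddings are exactly two rows short because of the newly added image tokens:

import torch

def pad_embedding_rows(old: torch.Tensor, target_rows: int) -> torch.Tensor:
    # Hypothetical helper (not from this commit): append freshly initialized
    # rows so a checkpoint built without the image tokens can still be loaded.
    missing = target_rows - old.size(0)
    if missing != 2:
        raise RuntimeError('size mismatch is not explained by the 2 image tokens')
    new_rows = old.new_zeros(missing, old.size(1)).normal_(0, 0.02)
    return torch.cat([old, new_rows], dim=0)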
135 changes: 119 additions & 16 deletions parlai/agents/image_seq2seq/modules.py
@@ -6,7 +6,7 @@
"""
Modules for the ImageSeq2seqAgent.
"""

from enum import Enum
from functools import reduce
from typing import List, Tuple, Optional, Union

@@ -16,11 +16,21 @@
from parlai.agents.transformer.modules import (
TransformerGeneratorModel,
TransformerEncoder,
_normalize,
)
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt


class FusionType(Enum):
"""
Encoder fusion type.
"""

EARLY = 'early'
LATE = 'late'


class ImageSeq2seqModel(TransformerGeneratorModel):
"""
ImageSeq2seqModel.
@@ -64,6 +74,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent):
output_scaling=opt['output_scaling'],
image_encoder_num_layers=opt['image_encoder_num_layers'],
image_features_dim=opt['image_features_dim'],
fusion=opt['image_fusion_type'],
)


@@ -97,6 +108,7 @@ def __init__(
image_features_dim=2048,
image_combination_mode='append',
n_image_tokens=1,
fusion='late',
):
"""
Override TransformerEncoder __init__.
@@ -110,10 +122,13 @@ def __init__(
self.img_dim = image_features_dim
self.image_combination_mode = image_combination_mode
self.n_image_tokens = n_image_tokens
self.fusion = FusionType(fusion)
if self.image_combination_mode == 'add' and self.n_image_tokens > 1:
raise ValueError(
'Image encoding cannot be added to context encoding if there is more than one image token!'
)
if self.fusion is FusionType.EARLY:
assert n_segments == 2, "must use segment embeddings for early fusion"
reduction_type = None # Must pass back unreduced encoding and mask
super().__init__(
n_heads=n_heads,
@@ -155,7 +170,10 @@ def _build_image_encoder(self):
self.image_encoder = nn.Sequential(*image_layers)

def encode_images(
self, images: Union[List[object], torch.Tensor]
self,
images: Union[List[object], torch.Tensor],
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Encode Images.
@@ -166,6 +184,10 @@ def encode_images(
:param images:
either a list of objects of length N, some of which may be None, or a
tensor of shape (batch size, self.img_dim)
:param positions:
positions for images, of size [len(images) x images.size(1)]
:param segments:
segments for images, of size [len(images)]
:return:
a (image_encoded, image_mask) tuple, where:
@@ -174,6 +196,19 @@ def encode_images(
self.embedding_size, representing the encoded batch of images
- image_mask is a torch.Tensor of dim N x self.n_image_tokens
"""
if positions is None:
positions = [None for _ in range(len(images))] # type: ignore
else:
positions = self.position_embeddings(positions)
if segments is None:
segments = [None for _ in range(len(images))] # type: ignore
else:
segments = self.segment_embeddings(segments)

# assertions for typing
assert positions is not None
assert segments is not None

image_masks = image_encoded = None
valid_inds = [
i
@@ -185,38 +220,40 @@ def encode_images(
image_mask_list = []
image_encoded_list = []

valid_imgs = torch.stack([images[i] for i in valid_inds])
valid_imgs = torch.stack([images[i] for i in valid_inds]) # type: ignore
valid_img_enc = self.image_encoder(valid_imgs)

img_num = 0
for i in range(len(images)):
if i in valid_inds:
image_mask_list.append(self.ones_mask)
image_encoded_list.append(valid_img_enc[img_num, :])
image_encoded_list.append(
self._add(
[valid_img_enc[img_num, :], positions[i], segments[i]]
)
)
img_num += 1
else:
image_mask_list.append(~self.ones_mask)
image_encoded_list.append(self.dummy_image_enc)
image_mask_list.append(~self.ones_mask) # type: ignore
image_encoded_list.append(self.dummy_image_enc) # type: ignore

image_masks = torch.stack(image_mask_list)
image_masks = torch.stack(image_mask_list) # type: ignore
image_encoded = torch.stack(image_encoded_list).reshape(
(len(images), self.n_image_tokens, self.embedding_size)
)
assert image_masks.shape == image_encoded.shape[:2]

return image_encoded, image_masks
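Per the docstring above, a tiny shape sketch (toy tensors only, not a call into the real encoder) of the contract this method returns:

import torch

# For a batch of N images with n_image_tokens = 1 and embedding_size = 300:
N, n_image_tokens, embedding_size = 4, 1, 300
image_encoded = torch.zeros(N, n_image_tokens, embedding_size)
image_mask = torch.ones(N, n_image_tokens, dtype=torch.bool)
assert image_mask.shape == image_encoded.shape[:2]  # mirrors the assert above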

def forward(
def forward( # type: ignore
self,
src_tokens: Optional[torch.Tensor],
src_tokens: Optional[torch.LongTensor],
image_features: Optional[Union[List[object], torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encode images with context.
Encodes tokens (if given) and images (if given) separately.
Combines via either addition, prepending, or appending the image embedding to
the context embedding.
Encodes tokens (if given) and images (if given) depending on fusion setting.
:param src_tokens:
A bsz x seq_len tensor of src_tokens; possibly None
@@ -228,6 +265,72 @@ def forward(
A (full_enc, full_mask) tuple, which represents the encoded context
and the mask
"""
if self.fusion is FusionType.LATE:
return self._forward_late_fusion(src_tokens, image_features)
elif self.fusion is FusionType.EARLY:
return self._forward_early_fusion(src_tokens, image_features)
else:
raise RuntimeError(f'Unsupported fusion type: {self.fusion}')

def _forward_early_fusion(
self,
src_tokens: Optional[torch.LongTensor],
image_features: Optional[Union[List[object], torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encode images with context.
Performs early fusion, whereby image embeddings and token embeddings are computed
and concatenated before being passed through the transformer layers.
Essentially overrides normal TransformerEncoder forward.
"""
context_tensor = context_mask = None
image_tensor = image_mask = None
if src_tokens is not None and image_features is not None:
assert src_tokens.size(0) == len(image_features)
if src_tokens is not None:
context_tensor, context_mask = self.forward_embedding(
src_tokens, segments=torch.zeros_like(src_tokens) # type: ignore
)
if image_features is not None:
valid_img = [v for v in image_features if isinstance(v, torch.Tensor)][0]
image_tensor, image_mask = self.encode_images(
image_features,
segments=torch.ones( # type: ignore
len(image_features), dtype=torch.long, device=valid_img.device
),
)

# perform early fusion
tensor = self._cat([context_tensor, image_tensor])
mask: torch.BoolTensor = self._cat([context_mask, image_mask]) # type: ignore

# WARNING: Below follows the rest of TransformerEncoder.forward
if self.variant == 'xlm':
tensor = _normalize(tensor, self.norm_embeddings)
# --dropout on the embeddings
tensor = self.dropout(tensor)
tensor *= mask.unsqueeze(-1).type_as(tensor)
# apply transformer layers
tensor = self.forward_layers(tensor, mask)
if self.variant == 'prelayernorm':
tensor = _normalize(tensor, self.norm_embeddings)
# reduce output
tensor, out_mask = self.reduce_output(tensor, mask)
return tensor, out_mask
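To make the shapes concrete, a toy sketch (plain tensors, not the real encoder) of what the early-fusion concatenation above does before the transformer layers run; segment ids 0/1 mark text vs. image positions as in the calls above:

import torch

bsz, seq_len, n_image_tokens, emb = 2, 5, 1, 300
context_tensor = torch.randn(bsz, seq_len, emb)           # forward_embedding output
image_tensor = torch.randn(bsz, n_image_tokens, emb)      # encode_images output
context_mask = torch.ones(bsz, seq_len, dtype=torch.bool)
image_mask = torch.ones(bsz, n_image_tokens, dtype=torch.bool)

# _cat joins along the sequence dimension, so the transformer layers see one
# joint sequence of text positions followed by image positions.
tensor = torch.cat([context_tensor, image_tensor], dim=1)
mask = torch.cat([context_mask, image_mask], dim=1)
print(tensor.shape, mask.shape)  # torch.Size([2, 6, 300]) torch.Size([2, 6])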

def _forward_late_fusion(
self,
src_tokens: Optional[torch.LongTensor],
image_features: Optional[Union[List[object], torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encode images with context.
Encodes tokens (if given) and images (if given) separately. Combines via either
addition, prepending, or appending the image embedding to the context embedding.
"""
context_encoded = context_mask = None
image_encoded = extra_masks = None
if src_tokens is not None and image_features is not None:
@@ -278,8 +381,8 @@ def _add(self, tensors: List[Optional[torch.Tensor]]) -> torch.Tensor:
:return:
The result of adding all non-null objects in tensors
"""
tensors = [t for t in tensors if t is not None]
return reduce(lambda a, b: a + b, tensors)
non_null_tensors: List[torch.Tensor] = [t for t in tensors if t is not None]
return reduce(lambda a, b: a + b, non_null_tensors)

def _cat(self, tensors: List[Optional[torch.Tensor]]) -> torch.Tensor:
"""
@@ -293,8 +396,8 @@ def _cat(self, tensors: List[Optional[torch.Tensor]]) -> torch.Tensor:
:return:
The result of concatenating all non-null objects in tensors
"""
tensors = [t for t in tensors if t is not None]
return torch.cat([t for t in tensors], dim=1)
non_null_tensors: List[torch.Tensor] = [t for t in tensors if t is not None]
return torch.cat([t for t in non_null_tensors], dim=1)
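A quick toy illustration (equivalent plain-torch calls, since the helpers are encoder methods) of how the two helpers above skip None entries, which is what lets text-only or image-only batches flow through either fusion path:

import torch

a = torch.ones(2, 3, 4)
b = 2 * torch.ones(2, 3, 4)

summed = a + b                      # what _add([a, None, b]) reduces to
joined = torch.cat([a, b], dim=1)   # what _cat([a, None, b]) returns
print(summed.shape, joined.shape)   # torch.Size([2, 3, 4]) torch.Size([2, 6, 4])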

def _fix_for_fp16(
self, full_enc: torch.Tensor, full_mask: Optional[torch.Tensor]