[feat][clip] minor clip update for processor, model and loss
Summary:
* make clip_text_processor more generic
* avoid using /=
* add temperature to the model

Reviewed By: vedanuj

Differential Revision: D29384557

fbshipit-source-id: 8b16081fc18a0978aee959640fe063bba741c884
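
The second and third summary points follow the usual CLIP-style pattern: normalize the embeddings with an out-of-place division rather than `/=`, and let a learnable temperature scale the logits. A minimal sketch of that pattern, with illustrative names only (`ToyCLIPHead`, `logit_scale`; this is not the actual MMF model code):

```python
import torch
import torch.nn as nn


class ToyCLIPHead(nn.Module):
    """Hypothetical head showing the temperature and the out-of-place division."""

    def __init__(self):
        super().__init__()
        # learnable temperature, stored on a log scale as in CLIP
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

    def forward(self, image_emb, text_emb):
        # out-of-place division instead of `/=`: in-place ops on tensors that
        # require grad can interfere with autograd
        image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        return self.logit_scale.exp() * image_emb @ text_emb.t()
```
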
Sasha Sheng authored and facebook-github-bot committed Jun 28, 2021
1 parent 397c163 commit 71dfcfc
Showing 2 changed files with 73 additions and 6 deletions.
26 changes: 20 additions & 6 deletions mmf/modules/losses.py
@@ -39,7 +39,7 @@ class CustomLoss(nn.Module):
from packaging import version
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence

from mmf.utils.distributed import gather_tensor_along_batch_with_backward, get_rank

@dataclass
class LossConfig:
@@ -757,14 +757,28 @@ def forward(self, sample_list: Dict[str, Tensor], model_output: Dict[str, Tensor
        assert (
            "embedding_1" in model_output and "embedding_2" in model_output
        ), "Embedding names must be available before loss calculation"

        embedding_1 = model_output["embedding_1"]
        embedding_2 = model_output["embedding_2"]

        mma = embedding_1 @ embedding_2.T
        labels = torch.arange(mma.shape[0], device=mma.device)
        loss1 = F.cross_entropy(mma, labels)
        loss2 = F.cross_entropy(mma.T, labels)
        return (loss1 + loss2) / 2
        assert embedding_1.size(0) == embedding_2.size(0), "batch size must match"
        per_gpu_batch_size = embedding_1.size(0)

        embedding_1_all_gpus = gather_tensor_along_batch_with_backward(embedding_1)
        embedding_2_all_gpus = gather_tensor_along_batch_with_backward(embedding_2)

        temperature = model_output["temperature"]

        logits_1 = torch.matmul(embedding_1, embedding_2_all_gpus.transpose(0, 1))
        logits_2 = torch.matmul(embedding_2, embedding_1_all_gpus.transpose(0, 1))
        labels = per_gpu_batch_size * get_rank() + torch.arange(
            per_gpu_batch_size, device=temperature.device
        )

        loss_1 = F.cross_entropy(logits_1, labels)
        loss_2 = F.cross_entropy(logits_2, labels)

        return (loss_1 + loss_2) / 2


@registry.register_loss("mse")
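
In the new loss above, the labels are offset by rank so that each worker's positives line up with its own slice of the all-gathered embeddings. A standalone toy illustration (not part of the diff) with two workers and a per-GPU batch of 4:

```python
import torch

# Each rank's logits have world_size * per_gpu_batch_size columns; the
# positive pair for local row i sits at column per_gpu_batch_size * rank + i.
per_gpu_batch_size = 4
for rank in range(2):
    labels = per_gpu_batch_size * rank + torch.arange(per_gpu_batch_size)
    print(f"rank {rank}: labels = {labels.tolist()}")
# rank 0: labels = [0, 1, 2, 3]
# rank 1: labels = [4, 5, 6, 7]
```
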
53 changes: 53 additions & 0 deletions mmf/utils/distributed.py
@@ -23,6 +23,44 @@
logger = logging.getLogger(__name__)


# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""

@staticmethod
def forward(ctx, x):
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(output, x)
return tuple(output)

@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
return all_gradients[dist.get_rank()]
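
Plain `torch.distributed.all_gather` returns tensors that are cut off from the autograd graph, which is why `GatherLayer` defines its own backward: all-reduce the incoming gradients, then keep this rank's slice. A minimal check of that behaviour, assuming MMF with this patch is importable and using a throwaway single-worker gloo group (hypothetical snippet, not from the repository):

```python
import os

import torch
import torch.distributed as dist

from mmf.utils.distributed import GatherLayer

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(2, 3, requires_grad=True)
(gathered,) = GatherLayer.apply(x)  # tuple with one tensor per worker
gathered.sum().backward()
print(x.grad is not None)  # True: gradients reach the input through the gather
dist.destroy_process_group()
```
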


class XLAGatherLayer(torch.autograd.Function):
"""
Gather tensors from all TPU workers with support for backward propagation.
"""

@staticmethod
def forward(ctx, x, dim):
ctx.dim = dim
tensor_list = xm.all_gather(x.unsqueeze(dim), dim=dim)
return tensor_list

@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
all_grad_output = xm.all_reduce(xm.REDUCE_SUM, grad_output)
return all_grad_output.select(dim, xm.get_ordinal()), None


def synchronize(message="sync-workers"):
    if is_xla():
        xm.rendezvous(message)
@@ -157,6 +195,21 @@ def gather_tensor_along_batch(tensor, dim=0):
return tensor_list


def gather_tensor_along_batch_with_backward(tensor, dim=0):
    world_size = get_world_size()

    if world_size < 2:
        return tensor

    if is_xla():
        tensor_list = XLAGatherLayer.apply(tensor, dim)
        tensor_list = tensor_list.flatten(start_dim=dim, end_dim=dim + 1)
    else:
        tensor_list = GatherLayer.apply(tensor)
        tensor_list = torch.cat(tensor_list, dim=dim)
    return tensor_list


def reduce_dict(dictionary):
    world_size = get_world_size()
    if world_size < 2:
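
A hypothetical two-worker smoke test for `gather_tensor_along_batch_with_backward` (gloo backend on CPU; assumes MMF with this patch is importable). It checks that the gather concatenates along the batch dimension and that gradients still reach the local tensor:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from mmf.utils.distributed import gather_tensor_along_batch_with_backward


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.randn(4, 8, requires_grad=True)
    gathered = gather_tensor_along_batch_with_backward(x)
    assert gathered.shape == (4 * world_size, 8)  # concatenated along dim 0

    gathered.sum().backward()
    assert x.grad is not None  # the custom backward keeps gradients flowing
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```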
