Skip to content

Commit

Permalink
fix mpp
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 16, 2021
1 parent 53884f5 commit 64a2ef6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 53 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ mpp_trainer = MPP(
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    """Return a dummy batch of 20 RGB images of size 256x256.

    Pixel values are drawn uniformly from [0, 1) rather than from a normal
    distribution, so the batch stays inside a bounded pixel range.
    """
    # torch.rand samples uniformly from [0, 1)
    return torch.rand(20, 3, 256, 256)

for _ in range(100):
images = sample_unlabelled_images()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.19.5',
version = '0.19.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
94 changes: 43 additions & 51 deletions vit_pytorch/mpp.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import math
from functools import reduce

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops import rearrange, repeat, reduce

# helpers

def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)

def prob_mask_like(t, prob):
    """Sample a random boolean mask over the first two dimensions of `t`.

    Args:
        t: 3-D tensor; only its first two sizes and its device are used.
        prob: probability that any given position is True.

    Returns:
        Bool tensor of shape (t.shape[0], t.shape[1]) on the same device as `t`.
    """
    batch, seq_length, _ = t.shape
    # sample on t's device so the mask can be combined with `t` (or used to
    # index it) without a CPU/accelerator device mismatch
    return torch.zeros((batch, seq_length), device = t.device).float().uniform_(0, 1) < prob


def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
Expand All @@ -31,74 +31,66 @@ def get_mask_subset_with_prob(patched_input, prob):


class MPPLoss(nn.Module):
    """Masked patch prediction loss.

    Each patch of the target image is reduced to its per-channel mean, the
    means are discretized into ``2 ** output_channel_bits`` bins per channel,
    and the per-channel bin indices are combined into one class label per
    patch. The loss is the cross entropy between the predicted patch logits
    and those labels, evaluated only at masked positions.
    """

    def __init__(
        self,
        patch_size,
        channels,
        output_channel_bits,
        max_pixel_val,
        mean,
        std
    ):
        """Store the patch geometry and the optional normalization statistics.

        Args:
            patch_size: side length of a square patch, in pixels.
            channels: number of image channels.
            output_channel_bits: bits per channel; each channel is bucketed
                into ``2 ** output_channel_bits`` bins.
            max_pixel_val: upper bound of the un-normalized pixel range.
            mean: per-channel mean used to normalize the input, or None.
            std: per-channel std used to normalize the input, or None.
        """
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

        # reshaped to (-1, 1, 1) so they broadcast over the two trailing dimensions
        self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
        self.std = torch.tensor(std).view(-1, 1, 1) if std else None

    def forward(self, predicted_patches, target, mask):
        """Compute the cross entropy over the masked patches.

        Args:
            predicted_patches: per-patch class logits
                (presumably (batch, num_patches, 2 ** (bits * channels)) - confirm against caller).
            target: image batch laid out as (batch, channels, height, width).
            mask: boolean (batch, num_patches) selector of the patches to score.

        Returns:
            Scalar cross-entropy loss.
        """
        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
        bin_size = mpv / (2 ** bits)

        # un-normalize input
        if exists(self.mean) and exists(self.std):
            # mean / std are plain tensors (not registered buffers), so they
            # do not follow the module across devices - move them explicitly
            target = target * self.std.to(device) + self.mean.to(device)

        # reshape target to patches
        target = target.clamp(max = mpv) # clamp just in case
        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()

        # bucket every channel mean using boundaries spaced bin_size apart
        channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        # combine the per-channel bin indices into one label, treating them
        # as digits of a base-(2 ** bits) number
        bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
        bin_mask = rearrange(bin_mask, 'c -> () () c')

        target_label = torch.sum(bin_mask * discretized_target, dim = -1)

        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
        return loss


# main class


class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None):
def __init__(
self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None
):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
Expand Down

0 comments on commit 64a2ef6

Please sign in to comment.