Skip to content

Commit

Permalink
Merge pull request lucidrains#101 from zankner/mpp-fix
Browse files Browse the repository at this point in the history
Mpp fix
  • Loading branch information
lucidrains authored Jun 16, 2021
2 parents 60ad4e2 + a2df363 commit e616b5d
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions vit_pytorch/mpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):

class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
max_pixel_val, mean, std):
super(MPPLoss, self).__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val

if mean:
self.mean = torch.tensor(mean).view(-1, 1, 1)
else:
self.mean = None
if std:
self.std = torch.tensor(std).view(-1, 1, 1)
else:
self.std = None

def forward(self, predicted_patches, target, mask):
# un-normalize input
if self.mean is not None and self.std is not None:
target = target * self.std + self.mean

# reshape target to patches
p = self.patch_size
target = rearrange(target,
Expand All @@ -64,7 +77,6 @@ def forward(self, predicted_patches, target, mask):
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)

predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
Expand All @@ -84,11 +96,13 @@ def __init__(self,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
random_patch_prob=0.5,
mean=None,
std=None):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
max_pixel_val, mean, std)

# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
Expand All @@ -102,7 +116,7 @@ def __init__(self,
self.random_patch_prob = random_patch_prob

# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))

def forward(self, input, **kwargs):
transformer = self.transformer
Expand Down

0 comments on commit e616b5d

Please sign in to comment.