Skip to content

Commit

Permalink
converting bin targets to hard labels
Browse files Browse the repository at this point in the history
  • Loading branch information
zankner committed Mar 7, 2021
1 parent fc14561 commit 73de1e8
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions vit_pytorch/mpp_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@ def forward(self, predicted_patches, target, mask):
c=c,
bi=bi)

bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)

predicted_patches = predicted_patches[mask]
discretized_target = discretized_target[mask]
loss = F.mse_loss(predicted_patches, discretized_target)
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
return loss


Expand All @@ -87,7 +92,7 @@ def __init__(self,
max_pixel_val)

# output transformation
self.to_bits = nn.Linear(dim, output_channel_bits * channels)
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

# vit related dimensions
self.patch_size = patch_size
Expand Down

0 comments on commit 73de1e8

Please sign in to comment.