Revert "Update bi_tuning.py"
This reverts commit c82be86.
thucbx99 committed Oct 26, 2021
1 parent 7293ff2 commit 9eade24
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions talib/finetune/bi_tuning.py
@@ -39,13 +39,11 @@ class Classifier(ClassifierBase):
- hn: (minibatch, `features_dim`)
"""

def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=128, finetune=True, pool_layer=None):
head = nn.Linear(backbone.out_features, num_classes)
head.weight.data.normal_(0, 0.01)
head.bias.data.fill_(0.0)
- super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,
- pool_layer=pool_layer)
+ super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune, pool_layer=pool_layer)
self.projector = nn.Linear(backbone.out_features, projection_dim)
self.projection_dim = projection_dim

@@ -77,7 +75,7 @@ def get_parameters(self, base_lr=1.0):
return params


- class BiTuning(nn.Module):
+ class Bituning(nn.Module):
"""
Bi-Tuning Module in `Bi-tuning of Pre-trained Representations <https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29>`_.
@@ -108,25 +106,24 @@ class BiTuning(nn.Module):
- logits_y: (minibatch, 1 + `num_classes` x `K`, `num_classes`)
- labels_c: (minibatch, 1 + `num_classes` x `K`)
"""

def __init__(self, encoder_q: Classifier, encoder_k: Classifier, num_classes, K=40, m=0.999, T=0.07):
- super(BiTuning, self).__init__()
+ super(Bituning, self).__init__()
self.K = K
self.m = m
self.T = T
self.num_classes = num_classes

# create the encoders
# num_classes is the output fc dimension
- self.encoder_q = encoder_q
+ self.encoder_q =encoder_q
self.encoder_k = encoder_k

for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient

# create the queue
self.register_buffer("queue_h", torch.randn(encoder_q.features_dim + 1, num_classes, K))
self.register_buffer("queue_h", torch.randn(encoder_q.features_dim+1, num_classes, K))
self.register_buffer("queue_z", torch.randn(encoder_q.projection_dim, num_classes, K))
self.queue_h = normalize(self.queue_h, dim=0)
self.queue_z = normalize(self.queue_z, dim=0)
@@ -148,8 +145,8 @@ def _dequeue_and_enqueue(self, h, z, label):

ptr = int(self.queue_ptr[label])
# replace the keys at ptr (dequeue and enqueue)
- self.queue_h[:, label, ptr: ptr + batch_size] = h.T
- self.queue_z[:, label, ptr: ptr + batch_size] = z.T
+ self.queue_h[:, label, ptr: ptr+batch_size] = h.T
+ self.queue_z[:, label, ptr: ptr+batch_size] = z.T

# move pointer
self.queue_ptr[label] = (ptr + batch_size) % self.K
@@ -177,16 +174,15 @@ def forward(self, im_q, im_k, labels):
for i in range(batch_size):
c = labels[i]
pos_samples = queue_z[:, c, :] # D x K
- neg_samples = torch.cat([queue_z[:, 0: c, :], queue_z[:, c + 1:, :]], dim=1).flatten(
- start_dim=1) # D x ((C-1)xK)
- ith_pos = torch.einsum('nc,ck->nk', [z_q[i: i + 1], pos_samples]) # 1 x D
- ith_neg = torch.einsum('nc,ck->nk', [z_q[i: i + 1], neg_samples]) # 1 x ((C-1)xK)
+ neg_samples = torch.cat([queue_z[:, 0: c, :], queue_z[:, c+1:, :]], dim=1).flatten(start_dim=1) # D x ((C-1)xK)
+ ith_pos = torch.einsum('nc,ck->nk', [z_q[i: i+1], pos_samples]) # 1 x D
+ ith_neg = torch.einsum('nc,ck->nk', [z_q[i: i+1], neg_samples]) # 1 x ((C-1)xK)
logits_z_pos = torch.cat((logits_z_pos, ith_pos), dim=0)
logits_z_neg = torch.cat((logits_z_neg, ith_neg), dim=0)

- self._dequeue_and_enqueue(h_k[i:i + 1], z_k[i:i + 1], labels[i])
+ self._dequeue_and_enqueue(h_k[i:i+1], z_k[i:i+1], labels[i])

- logits_z = torch.cat([logits_z_cur, logits_z_pos, logits_z_neg], dim=1) # Nx(1+C*K)
+ logits_z = torch.cat([logits_z_cur, logits_z_pos, logits_z_neg], dim=1)  # Nx(1+C*K)

# apply temperature
logits_z /= self.T
@@ -199,18 +195,16 @@ def forward(self, im_q, im_k, labels):
# current positive logits: Nx1
logits_y_cur = torch.einsum('nk,kc->nc', [h_q, w.T]) # N x C
queue_y = self.queue_h.clone().detach().to(device).flatten(start_dim=1).T # (C * K) x F
- logits_y_queue = torch.einsum('nk,kc->nc', [queue_y, w.T]).reshape(self.num_classes, -1,
- self.num_classes) # C x K x C
+ logits_y_queue = torch.einsum('nk,kc->nc', [queue_y, w.T]).reshape(self.num_classes, -1, self.num_classes) # C x K x C

logits_y = torch.Tensor([]).to(device)

for i in range(batch_size):
c = labels[i]
# calculate the ith sample in the batch
- cur_sample = logits_y_cur[i:i + 1, c] # 1
+ cur_sample = logits_y_cur[i:i+1, c] # 1
pos_samples = logits_y_queue[c, :, c] # K
- neg_samples = torch.cat([logits_y_queue[0: c, :, c], logits_y_queue[c + 1:, :, c]], dim=0).view(
- -1) # (C-1)*K
+ neg_samples = torch.cat([logits_y_queue[0: c, :, c], logits_y_queue[c + 1:, :, c]], dim=0).view(-1) # (C-1)*K

ith = torch.cat([cur_sample, pos_samples, neg_samples]) # 1+C*K
logits_y = torch.cat([logits_y, ith.unsqueeze(dim=0)], dim=0)
@@ -222,3 +216,4 @@ def forward(self, im_q, im_k, labels):
labels_c = torch.zeros([batch_size, self.K * self.num_classes + 1]).to(device)
labels_c[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))
return y_q, logits_z, logits_y, labels_c
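
For orientation, here is a minimal usage sketch of the module this revert restores, assembled only from the signatures and docstring shapes visible in the diff above. The ToyBackbone class, the explicit pool_layer, and all sizes and hyperparameter values are illustrative assumptions, not part of this commit; it also assumes ClassifierBase applies pool_layer to the backbone output, as the constructor arguments suggest.

import copy

import torch
import torch.nn as nn

# Import path taken from the file path in this commit; adjust if the package layout differs.
from talib.finetune.bi_tuning import Classifier, Bituning


class ToyBackbone(nn.Module):
    # Hypothetical stand-in backbone exposing the `out_features` attribute that
    # Classifier relies on (`nn.Linear(backbone.out_features, num_classes)` above).
    def __init__(self, out_features=256):
        super().__init__()
        self.out_features = out_features
        self.conv = nn.Conv2d(3, out_features, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)  # feature map; pooling is delegated to pool_layer


num_classes = 10
pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))
encoder_q = Classifier(ToyBackbone(), num_classes, projection_dim=128, pool_layer=pool)
encoder_k = copy.deepcopy(encoder_q)  # Bituning.__init__ copies q's weights into k and freezes k

bituning = Bituning(encoder_q, encoder_k, num_classes, K=40, m=0.999, T=0.07)

batch_size = 8
im_q = torch.randn(batch_size, 3, 32, 32)  # query images
im_k = torch.randn(batch_size, 3, 32, 32)  # key images (typically another augmentation of the same batch)
labels = torch.randint(0, num_classes, (batch_size,))

y_q, logits_z, logits_y, labels_c = bituning(im_q, im_k, labels)
# Shapes per the docstring and inline comments above:
#   logits_z: (batch_size, 1 + num_classes * K)
#   logits_y: (batch_size, 1 + num_classes * K, num_classes)
#   labels_c: (batch_size, 1 + num_classes * K)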
