Skip to content

Commit

Permalink
fix bugs and norm only for cl feature
Browse files Browse the repository at this point in the history
  • Loading branch information
tangminji committed Aug 3, 2023
1 parent 0bdfd14 commit 8f492b3
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion PreResNet_rours.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,10 @@ def forward(self, x, filter=None, lin=0, lout=5):
elif filter == 'dct':
out += self.filter_DCT(out)

logits = self.linear(out)
# Feature for CL
if self.norm:
out = F.normalize(out, dim=1)
logits = self.linear(out)

return logits, out

Expand Down
2 changes: 1 addition & 1 deletion loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, args, contrast_mode='all', base_temperature=0.07):
# Linear
# gamma_schedule = torch.linspace(0, 1, args.warm_up)
# Exponent
gamma_schedule[:args.warm_up] = torch.logspace(-args.warm_up, 0, args.warmup, np.e)
gamma_schedule[:args.warm_up] = torch.logspace(-args.warm_up, 0, args.warm_up, np.e)
self.gamma = gamma_schedule * args.gamma
self.eta = args.eta
self.warmup = args.warm_up
Expand Down
2 changes: 2 additions & 0 deletions main_ce1.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def main(args, params={}):
params = params['best']
args.nrun = True
args = update_args(params)
# TODO
args.record = True
res = main(args, params=params)
# TODO
if args.out_tmp:
Expand Down
2 changes: 2 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def train_ours(args, model, loader, optimizer, epoch, scheduler, criterion, net_
log_value('train_detail/delta_smooth/std', delta_smooth.std(), step=epoch)
log_value('train_detail/net_record/avg', net_record.mean(), step=epoch)
log_value('train_detail/net_record/std', net_record.std(), step=epoch)
if args.model_type == 'ours_cl':
log_value('train_detail/gamma', criterion.gamma[epoch], step=epoch)
# Print and log stats for the epoch
log_value('train/loss', train_loss.avg, step=epoch)
log(args.logpath, 'Time for Train-Epoch-{}/{}:{:.1f}s Acc:{}, Loss:{}\n'.format(epoch, args.n_epoch, time.time() - t0, correct.avg, train_loss.avg))
Expand Down

0 comments on commit 8f492b3

Please sign in to comment.