Skip to content

Commit

Permalink
edited tta train time
Browse files Browse the repository at this point in the history
  • Loading branch information
marwankefah committed Jun 9, 2022
1 parent a9118eb commit f44f0dd
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion deep_learning_code/configs/mask_rcnn_mix_colab.ini
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ box_detections_per_img=400
min_size=512
max_size=1024

need_label_correction=1
label_correction=1
need_label_correction=1
label_correction_threshold=0.9

box_score_thresh=0.05
Expand Down
2 changes: 1 addition & 1 deletion deep_learning_code/mask_rcnn_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def train(configs, snapshot_path):

train_iou, outputs_list_dict = evaluate(configs, epoch_num, initial_weak_labels_data_loader, configs.device,
configs.val_writer,
vis_every_iter=5, use_tta=True)
vis_every_iter=5, use_tta=configs.need_label_correction)

_, _, val_losses_reduced = evaluate(configs, epoch_num, cell_pose_test_dataloader, device=configs.device,
writer=configs.cell_pose_test_writer)
Expand Down
3 changes: 1 addition & 2 deletions deep_learning_code/odach_our/oda.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class TTAWrapper:

def __init__(self, model, tta, scale=[1], nms="wbf", iou_thr=0.5, skip_box_thr=0.5, weights=None, score_thresh=0.1):
self.ttas = self.generate_TTA(tta, scale)
self.model = model # .eval()
self.model = model.cuda() # .eval()
self.score_thresh = score_thresh
# set nms function
# default is weighted box fusion.
Expand All @@ -251,7 +251,6 @@ def generate_TTA(self, tta, scale):

def model_inference(self, img, targets):
with torch.no_grad():

results = self.model(img, targets)
return results

Expand Down
11 changes: 5 additions & 6 deletions deep_learning_code/reference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,6 @@ def coco_evaluate(outputs_list_dict, coco, epoch, writer, train_mask=False):

def correct_labels(configs, weak_label_chrisi_dataset, outputs_list_dict, epoch_num, max_epoch):
if configs.label_correction:
# if the flag is false, then check every time if it needs label correction
# if it is true one time, it will always be true
if not configs.need_label_correction:
configs.need_label_correction = utils.if_update(configs.train_iou_values, epoch_num, n_epoch=max_epoch,
threshold=configs.label_correction_threshold)

# it needs label correction, then output the label correction in a folder and reload it again
# no large cache memory
if configs.need_label_correction:
Expand Down Expand Up @@ -442,6 +436,11 @@ def correct_labels(configs, weak_label_chrisi_dataset, outputs_list_dict, epoch_
else:
logging.info('image with id {} have no output'.format(idx))

# if the flag is false, then check every time if it needs label correction
# if it is true one time, it will always be true
else:
configs.need_label_correction = utils.if_update(configs.train_iou_values, epoch_num, n_epoch=max_epoch,
threshold=configs.label_correction_threshold)

import os

Expand Down

0 comments on commit f44f0dd

Please sign in to comment.