Skip to content

Commit

Permalink
remove some unnecessary code
Browse files Browse the repository at this point in the history
  • Loading branch information
luoxiangde authored and luoxiangde committed Jul 22, 2021
1 parent 9c36a7f commit 0fc63df
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 25 deletions.
31 changes: 15 additions & 16 deletions code/test_urpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,18 @@ def Inference(FLAGS):


if __name__ == '__main__':
for exp_id in ["GTV_Uncertain_Aware_Deep_Supervised_V2_90_labeled"]:
print(exp_id)
model = os.listdir(
"/media/xdluo/ssd/Projects/UADS/data_ratio_model/{}".format(exp_id))[0]
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
default='../data/WestChina', help='Name of Experiment')
parser.add_argument('--exp', type=str,
default=exp_id, help='experiment_name')
parser.add_argument('--model', type=str,
default=model, help='model_name')
FLAGS = parser.parse_args()

metric = Inference(FLAGS)
print(metric)
print((metric[0] + metric[1]) / 2)

model = os.listdir(
"/media/xdluo/ssd/Projects/UADS/data_ratio_model/URPC")
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
default='../data/WestChina', help='Name of Experiment')
parser.add_argument('--exp', type=str,
default="URPC", help='experiment_name')
parser.add_argument('--model', type=str,
default=model, help='model_name')
FLAGS = parser.parse_args()

metric = Inference(FLAGS)
print(metric)

2 changes: 1 addition & 1 deletion code/train_adversarial_network_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def worker_init_fn(worker_id):
outputs_soft = torch.softmax(outputs, dim=1)

loss_ce = ce_loss(outputs[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_dice = dice_loss(
outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
supervised_loss = 0.5 * (loss_dice + loss_ce)
Expand Down
2 changes: 1 addition & 1 deletion code/train_cross_pseudo_supervision_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def worker_init_fn(worker_id):

writer.add_scalar('info/model2_val_dice_score',
avg_metric2[0, 0], iter_num)
writer.add_scalar('info/model1_val_hd95',
writer.add_scalar('info/model2_val_hd95',
avg_metric2[0, 1], iter_num)
logging.info(
'iteration %d : model2_dice_score : %f model2_hd95 : %f' % (
Expand Down
2 changes: 1 addition & 1 deletion code/train_entropy_minimization_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def worker_init_fn(worker_id):
outputs_soft = torch.softmax(outputs, dim=1)

loss_ce = ce_loss(outputs[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_dice = dice_loss(
outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
supervised_loss = 0.5 * (loss_dice + loss_ce)
Expand Down
2 changes: 1 addition & 1 deletion code/train_fully_supervised_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def worker_init_fn(worker_id):
outputs = model(volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)

loss_ce = ce_loss(outputs, label_batch[:])
loss_ce = ce_loss(outputs, label_batch)
loss_dice = dice_loss(outputs_soft, label_batch.unsqueeze(1))
loss = 0.5 * (loss_dice + loss_ce)
optimizer.zero_grad()
Expand Down
2 changes: 1 addition & 1 deletion code/train_uncertainty_aware_mean_teacher_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def worker_init_fn(worker_id):
torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)

loss_ce = ce_loss(outputs[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_dice = dice_loss(
outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
supervised_loss = 0.5 * (loss_dice + loss_ce)
Expand Down
8 changes: 4 additions & 4 deletions code/train_uncertainty_rectified_pyramid_consistency_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ def worker_init_fn(worker_id):
outputs_aux4_soft = torch.softmax(outputs_aux4, dim=1)

loss_ce_aux1 = ce_loss(outputs_aux1[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_ce_aux2 = ce_loss(outputs_aux2[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_ce_aux3 = ce_loss(outputs_aux3[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])
loss_ce_aux4 = ce_loss(outputs_aux4[:args.labeled_bs],
label_batch[:args.labeled_bs][:])
label_batch[:args.labeled_bs])

loss_dice_aux1 = dice_loss(
outputs_aux1_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
Expand Down

0 comments on commit 0fc63df

Please sign in to comment.