Skip to content

Commit

Permalink
some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
iliasprc committed Apr 23, 2020
1 parent 647f028 commit 3f8083b
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 63 deletions.
3 changes: 2 additions & 1 deletion lib/losses3D/BaseClass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from torch import nn as nn

from lib.losses3D.basic import expand_as_one_hot


Expand Down Expand Up @@ -47,7 +48,7 @@ def forward(self, input, target):

if self.skip_index_after is not None:
target = self.skip_target_channels(target, self.skip_index_after)

# print(input.size(),target.size())
assert input.size() == target.size(), "'input' and 'target' must have the same shape"
# get probabilities from logits
input = self.normalization(input)
Expand Down
2 changes: 1 addition & 1 deletion lib/medzoo/BaseModelClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def save_checkpoint(self,
if name is None:
name = "{}_{}_epoch.pth".format(
os.path.basename(directory), # netD or netG
epoch)
'last')

torch.save(ckpt_dict, os.path.join(directory, name))
if self.best_loss>loss:
Expand Down
10 changes: 5 additions & 5 deletions lib/medzoo/HyperDensenet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from torchsummary import summary

from lib.medzoo.BaseModelClass import BaseModel

"""
Expand Down Expand Up @@ -412,14 +413,13 @@ def forward(self, input):
return self.final(y)

def test(self,device='cpu'):

input_tensor = torch.rand(1, 2, 22, 22, 22)
ideal_out = torch.rand(1, self.num_classes, 22, 22, 22)
out = self.forward(input_tensor)
#assert ideal_out.shape == out.shape
summary(self.to(torch.device(device)), (2, 22, 22, 22),device=device)
torchsummaryX.summary(self,input_tensor.to(device))
print("HyperDenseNet test is complete",out.shape)
# assert ideal_out.shape == out.shape
# summary(self.to(torch.device(device)), (2, 22, 22, 22),device=device)
# torchsummaryX.summary(self,input_tensor.to(device))
print("HyperDenseNet test is complete", out.shape)


class HyperDenseNet(BaseModel):
Expand Down
6 changes: 2 additions & 4 deletions lib/medzoo/ResNet3D_VAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,5 @@ def testVAE():


#m = ResNet3dVAE(max_conv_channels=128, dim=(32, 32, 32), modalities=2, classes=4)
#m.test()
testVAE()


# m.test()
# testVAE()
11 changes: 6 additions & 5 deletions lib/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def training(self):
## TODO WRITER SCALARS END OF EPOCH

## TODO SAVE CHECKPOINT
val_loss = str(self.writer.data['val']['loss'] / self.writer.data['val']['count'])
val_loss = self.writer.data['val']['loss'] / self.writer.data['val']['count']
if self.args.save != None:
self.model.save_checkpoint(self.args.save,
epoch, val_loss,
Expand Down Expand Up @@ -64,8 +64,9 @@ def train_epoch(self, epoch):

self.writer.update_scores(batch_idx, loss_dice.item(), per_ch_score, 'train',
epoch * self.len_epoch + batch_idx)
# TODO display terminal statistics per batch or iteration steps
self.writer.display_terminal(partial_epoch, epoch, 'train')
## TODO display terminal statistics per batch or iteration steps
if (batch_idx % 100 == 0):
self.writer.display_terminal(partial_epoch, epoch, 'train')

# END OF EPOCH DISPLAY
self.writer.display_terminal(self.len_epoch, epoch, mode='train', summary=True)
Expand All @@ -81,7 +82,7 @@ def validate_epoch(self, epoch):
output = self.model(input_tensor)
loss, per_ch_score = self.criterion(output, target)

self.writer.update_scores(batch_idx, loss.item(), per_ch_score, 'train',
self.writer.update_scores(batch_idx, loss.item(), per_ch_score, 'val',
epoch * self.len_epoch + batch_idx)

self.writer.display_terminal(len(self.valid_data_loader), epoch, mode='train', summary=True)
self.writer.display_terminal(len(self.valid_data_loader), epoch, mode='val', summary=True)
28 changes: 25 additions & 3 deletions lib/utils/general.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import json
import os
import random
import shutil
import time

import torch
import os
import random, time
import numpy as np
import torch
import torch.backends.cudnn as cudnn

def reproducibility(args,seed):
torch.manual_seed(seed)
if args.cuda:
torch.cuda.manual_seed(seed)
np.random.seed(seed)
cudnn.deterministic = True
# FOR FASTER GPU TRAINING WHEN INPUT SIZE DOESN'T VARY
# LET'S TEST IT
cudnn.benchmark = True


def save_arguments(args, path):
with open(path + '/training_arguments.txt', 'w') as f:
json.dump(args.__dict__, f, indent=2)
f.close()


def datestr():
now = time.gmtime()
return '{:02}_{:02}___{:02}_{:02}'.format(now.tm_mday, now.tm_mon, now.tm_hour, now.tm_min)
return '{:02}_{:02}___{:02}_{:02}'.format(now.tm_mday, now.tm_mon, now.tm_hour, now.tm_min)


def shuffle_lists(a, b, seed=777):
Expand Down
44 changes: 31 additions & 13 deletions lib/visual3D_temp/BaseWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, args):
"val": dict((label, 0.0) for label in self.label_names)}
self.data['train']['loss'] = 0.0
self.data['val']['loss'] = 0.0
self.data['train']['count'] = 1
self.data['val']['count'] = 1
self.data['train']['count'] = 1.0
self.data['val']['count'] = 1.0

self.data['train']['dsc'] = 0.0
self.data['val']['dsc'] = 0.0
Expand All @@ -54,21 +54,27 @@ def display_terminal(self, iter, epoch, mode='train', summary=False):
"""
if summary:

info_print = "\n Epoch {} : {} summary Loss : {}".format(epoch, mode,
self.data[mode]['loss'] / self.data[mode]['count'])
info_print = "\n Epoch {:2d} : {} summary Loss : {:.4f} DSC : {:.4f} ".format(epoch, mode,
self.data[mode]['loss'] /
self.data[mode]['count'],
self.data[mode]['dsc'] /
self.data[mode]['count'])

for i in range(len(self.label_names)):
info_print += " {} : {}".format(self.label_names[i],
self.data[mode][self.label_names[i]] / self.data[mode]['count'])
info_print += " {} : {:.4f}".format(self.label_names[i],
self.data[mode][self.label_names[i]] / self.data[mode]['count'])

print(info_print)
else:

info_print = "partial epoch: {} Loss:{}".format(iter, self.data[mode]['loss'] / self.data[mode]['count'])
info_print = "partial epoch: {:.3f} Loss : {:.4f} DSC : {:.4f}".format(iter, self.data[mode]['loss'] /
self.data[mode]['count'],
self.data[mode]['dsc'] /
self.data[mode]['count'])

for i in range(len(self.label_names)):
info_print += " {} : {}".format(self.label_names[i],
self.data[mode][self.label_names[i]] / self.data[mode]['count'])
info_print += " {} : {:.4f}".format(self.label_names[i],
self.data[mode][self.label_names[i]] / self.data[mode]['count'])
print(info_print)

def create_stats_files(self, path):
Expand Down Expand Up @@ -99,7 +105,7 @@ def update_scores(self, iter, loss, channel_score, mode, writer_step):
num_channels = len(channel_score)
self.data[mode]['dsc'] += dice_coeff
self.data[mode]['loss'] += loss
self.data[mode]['count'] = iter
self.data[mode]['count'] = iter + 1

for i in range(num_channels):
self.data[mode][self.label_names[i]] += channel_score[i]
Expand All @@ -119,6 +125,18 @@ def _write_end_of_epoch(self, epoch):
{'train': self.data['train'][self.label_names[i]] / self.data['train']['count'],
'val': self.data['val'][self.label_names[i]] / self.data['train']['count'],
}, epoch)
# TODO write csv files
# self.csv_train.write()
# self.csv_test.write()

# TODO write labels accuracies in csv files

train_csv_line = 'Epoch:{:2d} Loss:{:.4f} DSC:{:.4f}'.format(epoch,
self.data['train']['loss'] / self.data['train'][
'count'],
self.data['train']['dsc'] / self.data['train'][
'count'])
val_csv_line = 'Epoch:{:2d} Loss:{:.4f} DSC:{:.4f}'.format(epoch,
self.data['val']['loss'] / self.data['val'][
'count'],
self.data['val']['dsc'] / self.data['val'][
'count'])
self.csv_train.write(train_csv_line + '\n')
self.csv_val.write(val_csv_line + '\n')
26 changes: 19 additions & 7 deletions tests/train_iseg2017.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
# Python libraries
import argparse, os
import argparse
import os

import torch
from torch.utils.tensorboard import SummaryWriter

# Lib files
import lib.utils as utils
import lib.medloaders as medical_loaders
import lib.medzoo as medzoo

import lib.train as train
# Lib files
import lib.utils as utils
from lib.losses3D import DiceLoss

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
seed = 1777777
torch.manual_seed(seed)
import torch.backends.cudnn as cudnn
import numpy as np


def main():
args = get_arguments()

seed = 1777777
torch.manual_seed(seed)
if args.cuda:
torch.cuda.manual_seed(seed)
np.random.seed(seed)
cudnn.deterministic = True
cudnn.benchmark = True

utils.make_dirs(args.save)
name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()

Expand All @@ -31,7 +43,7 @@ def main():
criterion = DiceLoss(classes=args.classes)

if args.cuda:
torch.cuda.manual_seed(seed)

model = model.cuda()
print("Model transferred in GPU.....")

Expand Down Expand Up @@ -61,10 +73,10 @@ def get_arguments():
parser.add_argument('--fold_id', default='1', type=str, help='Select subject for fold validation')
parser.add_argument('--lr', default=1e-3, type=float,
help='learning rate (default: 1e-3)')
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--cuda', action='store_true', default=True)
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--model', type=str, default='DENSEVOXELNET',
parser.add_argument('--model', type=str, default='VNET',
choices=('VNET', 'VNET2', 'UNET3D', 'DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
parser.add_argument('--opt', type=str, default='sgd',
choices=('sgd', 'adam', 'rmsprop'))
Expand Down
39 changes: 15 additions & 24 deletions tests/train_with_trainer_class.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import argparse
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn

import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
# Lib files
Expand All @@ -15,18 +11,13 @@

def main():
args = get_arguments()
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
## FOR REPRODUCIBILITY OF RESULTS
seed = 1777777
torch.manual_seed(seed)
if args.cuda:
torch.cuda.manual_seed(seed)
np.random.seed(seed)
cudnn.deterministic = True
# FOR FASTER GPU TRAINING WHEN INPUT SIZE DOESN'T VARY
# cudnn.benchmark = True
utils.reproducibility(args, seed)

utils.make_dirs(args.save)
utils.save_arguments(args, args.save)

training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(args,
path='.././datasets')
Expand All @@ -46,22 +37,22 @@ def main():

def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--batchSz', type=int, default=1)
parser.add_argument('--dataset_name', type=str, default="brats2018")
parser.add_argument('--batchSz', type=int, default=4)
parser.add_argument('--dataset_name', type=str, default="iseg2017")
parser.add_argument('--dim', nargs="+", type=int, default=(32, 32, 32))
parser.add_argument('--nEpochs', type=int, default=10)
parser.add_argument('--classes', type=int, default=5)
parser.add_argument('--samples_train', type=int, default=10)
parser.add_argument('--samples_val', type=int, default=10)
parser.add_argument('--inChannels', type=int, default=4)
parser.add_argument('--inModalities', type=int, default=4)
parser.add_argument('--split', default=0.8, type=float, help='Select percentage of training data(default: 0.8)')
parser.add_argument('--lr', default=1e-3, type=float,
parser.add_argument('--nEpochs', type=int, default=250)
parser.add_argument('--classes', type=int, default=4)
parser.add_argument('--samples_train', type=int, default=1000)
parser.add_argument('--samples_val', type=int, default=100)
parser.add_argument('--inChannels', type=int, default=2)
parser.add_argument('--inModalities', type=int, default=2)
parser.add_argument('--fold_id', default='1', type=str, help='Select subject for fold validation')
parser.add_argument('--lr', default=1e-2, type=float,
help='learning rate (default: 1e-3)')
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--cuda', action='store_true', default=True)
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--model', type=str, default='UNET3D',
parser.add_argument('--model', type=str, default='DENSENET2',
choices=('VNET', 'VNET2', 'UNET3D', 'DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
parser.add_argument('--opt', type=str, default='sgd',
choices=('sgd', 'adam', 'rmsprop'))
Expand Down

0 comments on commit 3f8083b

Please sign in to comment.