train_brats.py (forked from mobaidoctor/med-ddpm)
#-*- coding:utf-8 -*-
# +
from torchvision.transforms import RandomCrop, Compose, ToPILImage, Resize, ToTensor, Lambda
from diffusion_model.trainer_brats import GaussianDiffusion, Trainer
from diffusion_model.unet_brats import create_model
from dataset_brats import NiftiImageGenerator, NiftiPairImageGenerator
import argparse
import torch
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# -
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--seg_folder', type=str, default="dataset/brats2021/seg/")
parser.add_argument('-t1', '--t1_folder', type=str, default="dataset/brats2021/t1/")
parser.add_argument('-t2', '--t1ce_folder', type=str, default="dataset/brats2021/t1ce/")
parser.add_argument('-t3', '--t2_folder', type=str, default="dataset/brats2021/t2/")
parser.add_argument('-t4', '--flair_folder', type=str, default="dataset/brats2021/flair/")
parser.add_argument('--input_size', type=int, default=192)
parser.add_argument('--depth_size', type=int, default=144)
parser.add_argument('--num_channels', type=int, default=64)
parser.add_argument('--num_res_blocks', type=int, default=2)
parser.add_argument('--batchsize', type=int, default=1)
parser.add_argument('--epochs', type=int, default=10000000)
parser.add_argument('--timesteps', type=int, default=250)
parser.add_argument('--save_and_sample_every', type=int, default=1000)
parser.add_argument('--with_condition', action='store_true')
parser.add_argument('-r', '--resume_weight', type=str, default="model/model_brats.pt")
args = parser.parse_args()
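# Example invocation (a sketch: the flag names and dataset paths match the argparse
# defaults above; the local directory layout is an assumption about your setup):
#
#   python train_brats.py --with_condition \
#       --seg_folder dataset/brats2021/seg/ \
#       --t1_folder dataset/brats2021/t1/ \
#       --t1ce_folder dataset/brats2021/t1ce/ \
#       --t2_folder dataset/brats2021/t2/ \
#       --flair_folder dataset/brats2021/flair/ \
#       --batchsize 1 --timesteps 250 --resume_weight model/model_brats.pt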
seg_folder = args.seg_folder
t1_folder = args.t1_folder
t1ce_folder = args.t1ce_folder
t2_folder = args.t2_folder
flair_folder = args.flair_folder
input_size = args.input_size
depth_size = args.depth_size
num_channels = args.num_channels
num_res_blocks = args.num_res_blocks
save_and_sample_every = args.save_and_sample_every
with_condition = args.with_condition
resume_weight = args.resume_weight
# input tensor: (B, 1, H, W, D) value range: [-1, 1]
transform = Compose([
    Lambda(lambda t: torch.tensor(t).float()),
    Lambda(lambda t: t.permute(3, 0, 1, 2)),
    Lambda(lambda t: t.transpose(3, 1)),
])
input_transform = Compose([
    Lambda(lambda t: torch.tensor(t).float()),
    Lambda(lambda t: t.permute(3, 0, 1, 2)),
    Lambda(lambda t: t.transpose(3, 1)),
])
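# Shape sanity check (a sketch; the dummy array below is an assumption shaped like a
# preprocessed channels-last volume of size input_size x input_size x depth_size x 4):
#
#   import numpy as np
#   dummy = np.zeros((input_size, input_size, depth_size, 4), dtype=np.float32)
#   print(transform(dummy).shape)  # channels first, with transpose(3, 1) swapping two spatial axes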
if with_condition:
    dataset = NiftiPairImageGenerator(
        seg_folder,
        t1_folder,
        t1ce_folder,
        t2_folder,
        flair_folder,
        input_size=input_size,
        depth_size=depth_size,
        transform=input_transform if with_condition else transform,
        target_transform=transform,
        full_channel_mask=True
    )
else:
    print("Please modify this script for unconditional generation.")
in_channels = 4 + 4 if with_condition else 1  # 4 MRI modalities + 4-channel mask when conditioning
out_channels = 4
model = create_model(input_size, num_channels, num_res_blocks, in_channels=in_channels, out_channels=out_channels).cuda()
diffusion = GaussianDiffusion(
    model,
    image_size = input_size,
    depth_size = depth_size,
    timesteps = args.timesteps,   # number of diffusion timesteps
    loss_type = 'hybrid',         # loss variant used by the diffusion objective
    with_condition=with_condition,
    channels=out_channels
).cuda()
if len(resume_weight) > 0:
    weight = torch.load(resume_weight, map_location='cuda')
    diffusion.load_state_dict(weight['ema'])
    print("Model Loaded!")
trainer = Trainer(
    diffusion,
    dataset,
    image_size = input_size,
    depth_size = depth_size,
    train_batch_size = args.batchsize,
    train_lr = 1e-5,
    train_num_steps = args.epochs,          # total training steps
    gradient_accumulate_every = 2,          # gradient accumulation steps
    ema_decay = 0.995,                      # exponential moving average decay
    fp16 = False,                           # set True to turn on mixed precision training with apex
    save_and_sample_every = save_and_sample_every,
    results_folder = './results_brats',
    with_condition=with_condition
)
trainer.train()
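# Based on the arguments above, the Trainer is expected to write checkpoints and sampled
# volumes under ./results_brats every `save_and_sample_every` steps; see
# diffusion_model/trainer_brats.py for the exact behavior.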