-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
338 lines (278 loc) · 16.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import argparse, os, sys, time, gc, datetime
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from datasets import find_dataset_def
from models import *
from utils import *
import torch.distributed as dist
# Command-line interface. Parsed at import time so the module-level `args`,
# `num_gpus` and `is_distributed` globals are visible to every function below.
parser = argparse.ArgumentParser(description='A PyTorch Implementation of Efficient edge-preserving multi-view stereo network')
parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile'])
parser.add_argument('--model', default='mvsnet', help='select model')
parser.add_argument('--device', default='cuda', help='select device')
parser.add_argument('--gpu_device', type=str, default='2', help='gpu no.')
parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', default='', help='train datapath')
parser.add_argument('--testpath', help='test datapath')
parser.add_argument('--trainlist', default='lists/dtu/train.txt', help='train list')
parser.add_argument('--testlist', default='lists/dtu/test.txt', help='test list')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--lrepochs', type=str, default="6,8,9:2", help='epoch ids to downscale lr and the downscale rate')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--batch_size', type=int, default=4, help='train batch size')
parser.add_argument('--numdepth', type=int, default=128, help='the number of depth values')
parser.add_argument('--interval_scale', type=float, default=1.0, help='depth interval scale factor (passed to the dataset)')
parser.add_argument('--num_views', type=int, default=5, help='view numbers')
parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--logdir', default='./checkpoints/', help='the directory to save checkpoints/logs')
parser.add_argument('--resume', action='store_true', help='continue to train the model')
parser.add_argument('--summary_freq', type=int, default=100, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--eval_freq', type=int, default=1, help='eval freq')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--pin_m', action='store_true', help='data loader pin memory')
# torchrun (and torch.distributed.launch on recent PyTorch) supply the local
# rank through the LOCAL_RANK environment variable and/or a "--local-rank"
# flag; accept both spellings and fall back to the environment.
parser.add_argument("--local_rank", "--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0)))
parser.add_argument('--ndepths', type=str, default="32,16,8", help='ndepths')
parser.add_argument('--depth_inter_r', type=str, default="4,2,1", help='depth_intervals_ratio')
parser.add_argument('--dlossw', type=str, default="0.5,1.0,2.0", help='depth loss weight for different stage')
parser.add_argument('--num_groups', type=str, default="8,8,8,8,8", help='num_groups')
parser.add_argument('--sync_bn', action='store_true',help='enabling apex sync BN.')
parser.add_argument('--robust_train', action='store_true')
parser.add_argument('--random_crop', action='store_true')
parser.add_argument('--resize_wh', type=str, default='')
parser.add_argument('--crop_wh', type=str, default='')
parser.add_argument('--offsetnet_only', action='store_true')
parser.add_argument('--backbone_only', action='store_true')
# argparse runs `type` on string defaults, so a default of '' turned into
# float('') and aborted every run that omitted --ph_w. State the requirement
# explicitly instead of relying on that failure.
parser.add_argument('--ph_w', type=float, required=True, help='value forwarded to model_loss as ph_w (ph_loss weight)')
args = parser.parse_args()
# cuDNN benchmark mode autotunes convolution algorithms; it pays off when
# input shapes stay constant between iterations.
cudnn.benchmark = True
# Only takes effect if set before CUDA is first initialised in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_device
# WORLD_SIZE is exported by the distributed launcher; absent means one process.
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
is_distributed = num_gpus > 1
# training loop
def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args):
    """Train ``model`` for epochs ``start_epoch`` .. ``args.epochs - 1``.

    Args:
        model: the network, normally wrapped in DataParallel/DDP by the caller.
        model_loss: loss callable, forwarded to ``train_sample``.
        optimizer: optimizer over the model's trainable parameters.
        TrainImgLoader: training DataLoader; the LR scheduler steps once per batch.
        TestImgLoader: accepted for interface compatibility; not used here.
        start_epoch: first epoch to run (non-zero when resuming).
        args: parsed command-line arguments.

    ``args.lrepochs`` has the form ``"e1,e2,...:r"``. The listed epochs are
    converted to iteration milestones for ``WarmupMultiStepLR`` with
    ``gamma = 1 / r`` (warm-up factor 1/3 over 500 iterations). On the main
    process, scalars/images go to the module-level ``logger`` every
    ``args.summary_freq`` iterations and a checkpoint is written every
    ``args.save_freq`` epochs.
    """
    iters_per_epoch = len(TrainImgLoader)
    lr_spec = args.lrepochs.split(':')
    milestones = [iters_per_epoch * int(epoch_idx) for epoch_idx in lr_spec[0].split(',')]
    lr_gamma = 1 / float(lr_spec[1])
    # The scheduler steps per batch, so last_epoch is counted in iterations.
    lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0 / 3, warmup_iters=500,
                                     last_epoch=iters_per_epoch * start_epoch - 1)
    is_main_process = (not is_distributed) or (dist.get_rank() == 0)
    # DistributedSampler only reshuffles when told the current epoch; samplers
    # without set_epoch (e.g. RandomSampler) are left alone.
    train_sampler = getattr(TrainImgLoader, "sampler", None)

    for epoch_idx in range(start_epoch, args.epochs):
        print('Epoch {}:'.format(epoch_idx))
        if hasattr(train_sampler, "set_epoch"):
            train_sampler.set_epoch(epoch_idx)

        # training
        for batch_idx, sample in enumerate(TrainImgLoader):
            start_time = time.time()
            global_step = iters_per_epoch * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args)
            lr_scheduler.step()
            if is_main_process:
                if do_summary:
                    save_scalars(logger, 'train', scalar_outputs, global_step)
                    save_images(logger, 'train', image_outputs, global_step)
                print(
                    "Epoch {}/{}, Iter {}/{}, lr {:.6f}, train loss = {:.3f}, depth loss = {:.3f}, ph_loss = {:.3f}, res_loss = {:.3f}, less1 = {:.3f}, less3 = {:.3f}, time = {:.3f}".format(
                        epoch_idx, args.epochs, batch_idx, iters_per_epoch,
                        optimizer.param_groups[0]["lr"], loss,
                        scalar_outputs['depth_loss'], scalar_outputs['ph_loss'], scalar_outputs['res_loss'],
                        scalar_outputs["less1"], scalar_outputs["less3"],
                        time.time() - start_time))
            del scalar_outputs, image_outputs

        # checkpoint
        if is_main_process and (epoch_idx + 1) % args.save_freq == 0:
            # Save the wrapped module's weights when wrapped (DataParallel/DDP
            # expose it as `.module`), otherwise the model itself.
            net = getattr(model, "module", model)
            torch.save({
                'epoch': epoch_idx,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()},
                "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx))
        gc.collect()
def test(model, model_loss, TestImgLoader, args):
    """Evaluate ``model`` on every batch of ``TestImgLoader``.

    Per-batch scalars returned by ``test_sample_depth`` are accumulated in a
    ``DictAverageMeter``. The main process prints the loss and timing of each
    batch, the accumulated averages every 100 batches, and the final averages
    once the loader is exhausted.
    """
    is_main_process = (not is_distributed) or (dist.get_rank() == 0)
    avg_test_scalars = DictAverageMeter()
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args)
        avg_test_scalars.update(scalar_outputs)
        del scalar_outputs, image_outputs
        if is_main_process:
            print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), loss,
                                                                        time.time() - start_time))
            if batch_idx % 100 == 0:
                print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean()))
    if is_main_process:
        print("final", avg_test_scalars.mean())
def train_sample(model, model_loss, optimizer, sample, args):
    """Run one optimisation step on ``sample`` and gather logging outputs.

    Returns a tuple ``(loss, scalar_outputs, image_outputs)``: the loss and
    scalar dict are converted with ``tensor2float``, the image dict with
    ``tensor2numpy``. When running distributed, scalars are reduced across
    processes with ``reduce_scalar_outputs`` before conversion.
    """
    model.train()
    optimizer.zero_grad()

    batch = tocuda(sample)
    depth_by_stage = batch["depth"]
    mask_by_stage = batch["mask"]
    stage_count = len([int(nd) for nd in args.ndepths.split(",") if nd])
    # Ground truth used for the metrics below is taken from stage (N - 2).
    stage_key = "stage{}".format(str(stage_count - 2))
    depth_gt = depth_by_stage[stage_key]
    mask = mask_by_stage[stage_key]

    outputs = model(batch["imgs"], batch["proj_matrices"], batch["depth_min"], batch["depth_max"], batch["depth_interval"])
    depth_est = outputs["depth"]

    stage_weights = [float(w) for w in args.dlossw.split(",") if w]
    loss, depth_loss, ph_loss, res_loss, less1, less3 = model_loss(
        outputs, depth_by_stage, mask_by_stage,
        batch["proj_matrices"], batch["imgs"], batch["depth_interval"],
        args.offsetnet_only, args.backbone_only, args.ph_w,
        dlossw=stage_weights)
    loss.backward()
    optimizer.step()

    valid = mask > 0.5
    scalar_outputs = {
        "loss": loss,
        "depth_loss": depth_loss,
        "ph_loss": ph_loss,
        "res_loss": res_loss,
        "less1": less1 * 100,
        "less3": less3 * 100,
        "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, valid),
    }
    for metric_name, threshold in (("thres2mm_error", 2), ("thres4mm_error", 4), ("thres8mm_error", 8)):
        scalar_outputs[metric_name] = Thres_metrics(depth_est, depth_gt, valid, threshold)

    image_outputs = dict(
        depth_est=depth_est * mask,
        depth_est_nomask=depth_est,
        depth_gt=sample["depth"]["stage1"],
        ref_img=sample["imgs"][:, 0],
        mask=sample["mask"]["stage1"],
        errormap=(depth_est - depth_gt).abs() * mask,
    )

    if is_distributed:
        scalar_outputs = reduce_scalar_outputs(scalar_outputs)
    return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs)
@make_nograd_func
def test_sample_depth(model, model_loss, sample, args):
    """Evaluate ``model`` on one batch and gather metric/image outputs.

    Wrapped by ``make_nograd_func`` (presumably runs under ``torch.no_grad``).
    Returns ``(loss, scalar_outputs, image_outputs)`` converted with
    ``tensor2float`` / ``tensor2numpy``; when distributed, scalars are first
    reduced across processes with ``reduce_scalar_outputs``.
    """
    # Under DDP, evaluate the bare module: no gradient synchronisation is
    # needed, and going through the DDP wrapper would also run its
    # forward-time buffer broadcast on every evaluation batch.
    model_eval = model.module if is_distributed else model
    model_eval.eval()

    sample_cuda = tocuda(sample)
    depth_gt_ms = sample_cuda["depth"]
    mask_ms = sample_cuda["mask"]
    num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd])
    # Metrics use the ground truth of the last stage listed in --ndepths.
    stage_key = "stage{}".format(num_stage)
    depth_gt = depth_gt_ms[stage_key]
    mask = mask_ms[stage_key]

    outputs = model_eval(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_min"],
                         sample_cuda["depth_max"], sample_cuda["depth_interval"])
    depth_est = outputs["depth"]
    # NOTE(review): this call passes (outputs, depth, mask, depth_interval) and
    # unpacks 4 results, while train_sample passes proj_matrices, imgs,
    # offsetnet_only, backbone_only and ph_w and unpacks 6. They are unlikely
    # to match the same model_loss signature — confirm against its definition.
    loss, depth_loss, less1, less3 = model_loss(outputs, depth_gt_ms, mask_ms, sample_cuda["depth_interval"],
                                                dlossw=[float(e) for e in args.dlossw.split(",") if e])

    valid = mask > 0.5
    scalar_outputs = {"loss": loss,
                      "depth_loss": depth_loss,
                      "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, valid),
                      "thres2mm_error": Thres_metrics(depth_est, depth_gt, valid, 2),
                      "thres4mm_error": Thres_metrics(depth_est, depth_gt, valid, 4),
                      "thres8mm_error": Thres_metrics(depth_est, depth_gt, valid, 8),
                      "thres14mm_error": Thres_metrics(depth_est, depth_gt, valid, 14),
                      "thres20mm_error": Thres_metrics(depth_est, depth_gt, valid, 20),
                      # absolute error restricted to pixels whose error lies in each [lo, hi) band
                      "thres2mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [0, 2.0]),
                      "thres4mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [2.0, 4.0]),
                      "thres8mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [4.0, 8.0]),
                      "thres14mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [8.0, 14.0]),
                      "thres20mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [14.0, 20.0]),
                      "thres>20mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, valid, [20.0, 1e5]),
                      }
    image_outputs = {"depth_est": depth_est * mask,
                     "depth_est_nomask": depth_est,
                     "depth_gt": sample["depth"]["stage1"],
                     "ref_img": sample["imgs"][:, 0],
                     "mask": sample["mask"]["stage1"],
                     "errormap": (depth_est - depth_gt).abs() * mask}
    if is_distributed:
        scalar_outputs = reduce_scalar_outputs(scalar_outputs)
    return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs)
if __name__ == '__main__':
    # parse arguments and check
    # Use parser.error instead of assert: asserts disappear under `python -O`,
    # and parser.error prints usage and exits with a clear message.
    if args.sync_bn and not is_distributed:
        parser.error("must be distributed: --sync_bn requires a multi-process (WORLD_SIZE > 1) run")
    if args.resume:
        if args.mode != "train":
            parser.error("--resume requires --mode train")
        if args.loadckpt is not None:
            parser.error("--resume cannot be combined with --loadckpt")
    if args.testpath is None:
        args.testpath = args.trainpath

    if is_distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # set_random_seed(args.seed)
    device = torch.device(args.device)

    if (not is_distributed) or (dist.get_rank() == 0):
        # create the TensorBoard logger (train mode only); `logger` is a
        # module-level global that train() writes to
        if args.mode == "train":
            if not os.path.isdir(args.logdir):
                os.makedirs(args.logdir)
            current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
            print("current time", current_time_str)
            print("creating new summary file")
            logger = SummaryWriter(args.logdir)
        print("argv:", sys.argv[1:])
        print_args(args)

    # model, optimizer
    model = MVSNet(ndepths=[int(nd) for nd in args.ndepths.split(",") if nd],
                   depth_interals_ratio=[float(d_i) for d_i in args.depth_inter_r.split(",") if d_i],
                   num_groups=[int(ng) for ng in args.num_groups.split(",") if ng],
                   offsetnet_only=args.offsetnet_only, backbone_only=args.backbone_only)
    model.to(device)
    if args.dataset == "blendedmvs":
        model_loss = mvsnet_loss
    else:
        model_loss = mvsnet_loss_dtu

    if args.sync_bn:
        print("synced BN of model")
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Created after any SyncBN conversion so it sees the final parameter set.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)

    # load parameters
    start_epoch = 0
    if args.resume:
        saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
        if not saved_models:
            raise FileNotFoundError("--resume: no .ckpt files found in {}".format(args.logdir))
        # sort by the integer after the last '_' (the epoch in model_<epoch>.ckpt)
        saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))
        # use the latest checkpoint file
        loadckpt = os.path.join(args.logdir, saved_models[-1])
        print("resuming", loadckpt)
        state_dict = torch.load(loadckpt, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict['model'])
        optimizer.load_state_dict(state_dict['optimizer'])
        start_epoch = state_dict['epoch'] + 1
    elif args.loadckpt:
        # load checkpoint file specified by args.loadckpt
        print("loading model {}".format(args.loadckpt))
        state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict['model'])
    else:
        # fresh start: Xavier-uniform init for trainable conv weights,
        # BatchNorm affine params reset to weight=1, bias=0
        for m in model.modules():
            if any([isinstance(m, T) for T in [nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d]]):
                if m.weight.requires_grad:
                    nn.init.xavier_uniform_(m.weight)
            elif any([isinstance(m, T) for T in [nn.BatchNorm2d, nn.BatchNorm3d]]):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    if (not is_distributed) or (dist.get_rank() == 0):
        print("start at epoch {}".format(start_epoch))
        print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    if is_distributed:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            # find_unused_parameters=False,
            # this should be removed if we update BatchNorm stats
            # broadcast_buffers=False,
        )
    else:
        if torch.cuda.is_available():
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)

    # dataset, dataloader
    MVSDataset = find_dataset_def(args.dataset)
    train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", args.num_views, args.resize_wh, args.crop_wh, args.numdepth, args.interval_scale, args.robust_train, args.random_crop)
    test_dataset = MVSDataset(args.testpath, args.testlist, "test", args.num_views, args.resize_wh, args.crop_wh, args.numdepth, args.interval_scale, args.robust_train, args.random_crop)
    if is_distributed:
        train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
        test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
        TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=1, drop_last=True, pin_memory=args.pin_m)
        TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=1, drop_last=False, pin_memory=args.pin_m)
    else:
        TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=1, drop_last=True, pin_memory=args.pin_m)
        TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=1, drop_last=False, pin_memory=args.pin_m)

    if args.mode == "train":
        train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args)
    elif args.mode == "test":
        test(model, model_loss, TestImgLoader, args)
    else:
        raise NotImplementedError