-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain_DCAdapt.py
493 lines (411 loc) · 18.6 KB
/
main_DCAdapt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
import os.path as osp
from functools import partial
import gc
import traceback
from models.epc_loss import ConsistLoss
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import transforms
import datasets
import models
import cmd_args
from main_utils import *
from tqdm import tqdm
from models import EPE3DLoss
from evaluation_bnn import evaluate, evaluate_MT
from models import ConsistLoss
# Module-level side effects, executed as soon as this file is imported:
# 1) Join the default torch.distributed process group using the NCCL backend.
#    No init_method is given, so rendezvous information is expected from the
#    launcher environment. NOTE(review): confirm the launch command
#    (torchrun / torch.distributed.launch) sets the required variables.
torch.distributed.init_process_group(backend="nccl") # distributed
# 2) Enable autograd anomaly detection for the whole process. This is a
#    debugging aid that PyTorch documents as slowing execution; consider
#    disabling it for production training runs.
torch.autograd.set_detect_anomaly(True)
class WeightEMA(object):
    """
    Exponential moving average weight optimizer for mean teacher model.

    On construction the teacher tensors are overwritten with a copy of the
    student tensors. Each ``step()`` then updates every teacher tensor in place:

        teacher = alpha * teacher + (1 - alpha) * student

    Tensors are paired positionally, so both sequences must list corresponding
    parameters in the same order.

    Args:
        params: iterable of teacher parameters/tensors (modified in place).
        src_params: iterable of student parameters/tensors (only read).
        alpha: EMA decay; values closer to 1 make the teacher change more slowly.

    Raises:
        ValueError: if the two sequences have different lengths. Without this
            check ``zip`` would silently skip the unmatched tail.
    """

    def __init__(self, params, src_params, alpha=0.999):
        self.params = list(params)
        self.src_params = list(src_params)
        if len(self.params) != len(self.src_params):
            raise ValueError(
                'WeightEMA: teacher has {} parameter tensors but student has {}'.format(
                    len(self.params), len(self.src_params)))
        self.alpha = alpha
        # Start the teacher as an exact copy of the student. copy_() also works
        # for 0-dim tensors, where slicing with [:] raises IndexError.
        for p, src_p in zip(self.params, self.src_params):
            p.data.copy_(src_p.data)
        print('teacher model initialized ...')

    def step(self):
        """Blend the current student weights into the teacher, in place."""
        one_minus_alpha = 1.0 - self.alpha
        for p, src_p in zip(self.params, self.src_params):
            # Scale in place, then add the scaled student tensor without
            # allocating a temporary ``src_p * (1 - alpha)`` per parameter.
            p.data.mul_(self.alpha).add_(src_p.data, alpha=one_minus_alpha)
def _make_distributed_loader(dataset, batch_size, workers):
    """Build a DistributedSampler-backed DataLoader; return ``(loader, sampler)``.

    The sampler gives each rank its own shard of ``dataset``. The sampler is
    returned so callers can call ``sampler.set_epoch(epoch)`` to reshuffle.
    The last incomplete batch is dropped.
    """
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # shuffling is delegated to the sampler
        num_workers=workers,
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
    )
    return loader, sampler


def _wrap_ddp(model, device, local_rank):
    """Move ``model`` to ``device`` and wrap it in DistributedDataParallel."""
    model = model.to(device)
    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])


def main():
    """Build data and models from the YAML config, then train or evaluate.

    The config path is the last command-line argument (``sys.argv[-1]``). The
    parsed config is stored in the module-global ``args``, which ``train()``
    and ``validate()`` also read. torch.distributed must already be
    initialized, because ``get_rank()`` is called.

    If ``args.evaluate`` is set, only the student model is evaluated on the
    validation set, and the string returned by ``evaluate`` is returned.

    Otherwise the student is trained with a supervised loss on the source set
    plus a consistency loss on the target set. The teacher tracks the student
    through an EMA (``WeightEMA``). Both models are validated every epoch,
    checkpoints are written every epoch, and a summary string with the best
    loss is returned.
    """
    # ensure numba JIT is on
    if 'NUMBA_DISABLE_JIT' in os.environ:
        del os.environ['NUMBA_DISABLE_JIT']

    # parse arguments (module-global: train()/validate() read it)
    global args
    args = cmd_args.parse_args_from_yaml(sys.argv[-1])

    # -------------------- logging args --------------------
    os.makedirs(args.ckpt_dir, mode=0o777, exist_ok=True)
    logger = Logger(osp.join(args.ckpt_dir, 'log'))
    logger.log('sys.argv:\n' + ' '.join(sys.argv))

    # get_rank() is the *global* rank. The CUDA device index must be the
    # per-node rank, so prefer LOCAL_RANK when the launcher provides it.
    # On a single node the two values are equal.
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    print('local rank ', local_rank)
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda:%d' % local_rank)

    os.environ['NUMBA_NUM_THREADS'] = str(args.workers)
    logger.log('NUMBA NUM THREADS\t' + os.environ['NUMBA_NUM_THREADS'])

    for arg in sorted(vars(args)):
        logger.log('{:20s} {}'.format(arg, getattr(args, arg)))
    logger.log('')

    # -------------------- dataset & loader --------------------
    if not args.evaluate:
        # source dataset (provides the ground-truth ``sf`` used by the supervised loss)
        train_dataset_source = datasets.__dict__[args.source_dataset](
            train=True,
            transform=transforms.Augmentation(args.aug_together,
                                              args.aug_pc2,
                                              args.data_process,
                                              args.num_points,
                                              args.allow_less_points),
            gen_func=transforms.GenerateDataUnsymmetric(args),
            args=args
        )
        logger.log('train_dataset_source: ' + str(train_dataset_source))
        train_loader, train_sampler = _make_distributed_loader(
            train_dataset_source, args.batch_size, args.workers)

        # target dataset
        train_dataset_target = datasets.__dict__[args.target_dataset](
            train=True,
            transform=transforms.Augmentation(args.aug_together,
                                              args.aug_pc2,
                                              args.data_process,
                                              args.num_points,
                                              args.allow_less_points),
            gen_func=transforms.GenerateDataUnsymmetric(args),
            args=args
        )
        logger.log('train_dataset_target: ' + str(train_dataset_target))
        target_loader, target_sampler = _make_distributed_loader(
            train_dataset_target, args.batch_size, args.workers)

    # validation dataset (built in both training and evaluation mode)
    val_dataset = datasets.__dict__[args.val_dataset](
        train=False,
        transform=transforms.ProcessData(args.data_process,
                                         args.num_points,
                                         args.allow_less_points),
        gen_func=transforms.GenerateDataUnsymmetric(args),
        args=args
    )
    logger.log('val_dataset: ' + str(val_dataset))
    val_loader, _ = _make_distributed_loader(val_dataset, args.batch_size, args.workers)

    # -------------------- create model --------------------
    # ---- student model ----
    student_model = models.__dict__[args.arch](args)
    if not args.evaluate:
        init_func = partial(init_weights_multi, init_type=args.init, gain=args.gain)
        student_model.apply(init_func)
    student_model = _wrap_ddp(student_model, device, local_rank)

    criterion_EPE3D = EPE3DLoss().to(device)
    criterion_consist = ConsistLoss().to(device)

    # ---- teacher model (same architecture; its weights are overwritten by WeightEMA below) ----
    teacher_model = _wrap_ddp(models.__dict__[args.arch](args), device, local_rank)

    if args.evaluate:
        torch.backends.cudnn.enabled = False
    else:
        # cudnn autotunes kernels per input shape; if input sizes change at each
        # iteration it re-benchmarks every new size, which can hurt runtime.
        # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

    # -------------------- resume (student) --------------------
    checkpoint = None
    if args.resume:
        if osp.isfile(args.resume):
            logger.log("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            student_model.load_state_dict(checkpoint['state_dict'], strict=True)
            logger.log("=> loaded checkpoint '{}' (start epoch {}, min loss {})"
                       .format(args.resume, checkpoint['epoch'], checkpoint['min_loss']))
        else:
            logger.log("=> no checkpoint found at '{}'".format(args.resume))
    else:
        args.start_epoch = 0

    # -------------------- evaluation --------------------
    if args.evaluate:
        res_str = evaluate(val_loader, student_model, logger, args)
        logger.close()
        return res_str

    # -------------------- optimizer --------------------
    # student: Adam over its trainable parameters
    student_model_params = [p for p in student_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params=student_model_params,
                                 lr=args.lr,
                                 weight_decay=0)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # teacher: frozen for autograd, updated only through the EMA of the student
    teacher_model_params = [p for p in teacher_model.parameters() if p.requires_grad]
    for p in teacher_model_params:
        p.requires_grad = False
    teacher_optimizer = WeightEMA(teacher_model_params, student_model_params, alpha=args.alpha)

    # ------------ resume for teacher model -------------
    # Falls back to the student checkpoint when no dedicated teacher checkpoint is given.
    resume_t = getattr(args, 'resume_t', None) or args.resume
    if resume_t and osp.isfile(resume_t):
        logger.log("=> loading checkpoint teacher '{}'".format(resume_t))
        checkpoint_t = torch.load(resume_t, map_location=device)
        args.start_epoch = checkpoint_t['epoch']
        teacher_model.load_state_dict(checkpoint_t['state_dict'], strict=True)
        logger.log("=> loaded checkpoint '{}' (start epoch {}, min loss {})"
                   .format(resume_t, checkpoint_t['epoch'], checkpoint_t['min_loss']))
    elif resume_t:
        logger.log("=> no teacher checkpoint found at '{}'".format(resume_t))

    if getattr(args, 'reset_lr', False):
        print('reset lr')
        reset_learning_rate(optimizer, args)

    # -------------------- main loop --------------------
    min_train_loss = None
    min_val_loss = None
    best_train_epoch = None
    best_val_epoch = None
    do_eval = True
    for epoch in range(args.start_epoch, args.epochs):
        # DistributedSampler seeds its shuffle from the epoch number. Without
        # set_epoch() every epoch would replay the same sample order.
        train_sampler.set_epoch(epoch)
        target_sampler.set_epoch(epoch)

        old_lr = optimizer.param_groups[0]['lr']
        adjust_learning_rate(optimizer, epoch, args)
        lr = optimizer.param_groups[0]['lr']
        if old_lr != lr:
            print('Switch lr!')
        logger.log('lr: ' + str(lr))

        train_loss = train(train_loader, target_loader, student_model, teacher_model, criterion_EPE3D,
                           criterion_consist, optimizer, teacher_optimizer, epoch, logger, device)
        gc.collect()

        is_train_best = best_train_epoch is None or train_loss < min_train_loss
        if is_train_best:
            min_train_loss = train_loss
            best_train_epoch = epoch

        if do_eval:
            logger.log('--------eval for student---------')
            val_loss = validate(val_loader, student_model, criterion_EPE3D, logger)
            gc.collect()
            is_val_best = best_val_epoch is None or val_loss < min_val_loss
            if is_val_best:
                min_val_loss = val_loss
                best_val_epoch = epoch
                logger.log("New min val loss!")
            # Teacher results are only logged; "best" selection uses the student.
            logger.log('--------eval for teacher---------')
            validate(val_loader, teacher_model, criterion_EPE3D, logger)
            gc.collect()

        min_loss = min_val_loss if do_eval else min_train_loss
        is_best = is_val_best if do_eval else is_train_best
        # for student
        save_checkpoint({
            'epoch': epoch + 1,  # next start epoch
            'arch': args.arch,
            'state_dict': student_model.state_dict(),
            'min_loss': min_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.ckpt_dir)
        # for teacher (stores the student optimizer state; the teacher has no optimizer of its own)
        save_checkpoint_t({
            'epoch': epoch + 1,  # next start epoch
            'arch': args.arch,
            'state_dict': teacher_model.state_dict(),
            'min_loss': min_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.ckpt_dir)

    if best_train_epoch is None:
        # start_epoch >= epochs (e.g. resuming a finished run): nothing was
        # trained, so there is no best loss to format.
        result_str = 'No epochs to run (start epoch {}, epochs {})'.format(args.start_epoch, args.epochs)
        logger.log(result_str)
        logger.close()
        return result_str

    train_str = 'Best train loss: {:.5f} at epoch {:3d}'.format(min_train_loss, best_train_epoch)
    logger.log(train_str)
    if do_eval:
        val_str = 'Best val loss: {:.5f} at epoch {:3d}'.format(min_val_loss, best_val_epoch)
        logger.log(val_str)
    logger.close()
    return val_str if do_eval else train_str
def split_and_load(batch):
    """Convert every element of ``batch`` into its own list.

    Returns a new outer list whose i-th entry is ``list(batch[i])``. The
    containers are new, but the contained items are not copied.
    """
    return [list(item) for item in batch]
def train(train_loader, target_loader, student_model, teacher_model, criterion_epe, criterion_consist, optimizer,
          teacher_optimizer, epoch, logger, device):
    """Run one training epoch and return the mean source EPE3D loss.

    The source and target loaders are consumed in lockstep, so the epoch ends
    at the shorter of the two. For each batch:
      * the student predicts on the source pair; the supervised EPE3D loss is
        computed against ``sf``;
      * the student predicts on the augmented target view (``pc1_target_2``)
        and the teacher on the original view (``pc1_target``); ``criterion_consist``
        compares the two;
      * the student takes an optimizer step, then ``teacher_optimizer.step()``
        updates the teacher.

    A RuntimeError whose message matches a CUDA out-of-memory error skips the
    batch. Any other RuntimeError terminates the process with exit code 1.
    Reads the module-global ``args`` (``print_freq``).
    """
    epe3d_losses = AverageMeter()
    consis_losses = AverageMeter()
    student_model.train()
    # NOTE(review): the teacher's train/eval mode is never set here. It starts in
    # train mode, but validate() switches it to eval, so from the second epoch on
    # the teacher runs in eval mode. Confirm which mode is intended.

    # zip() stops at the shorter loader, so this is the real number of iterations.
    num_iters = min(len(train_loader), len(target_loader))
    for i, (train_batch, target_batch) in tqdm(enumerate(zip(train_loader, target_loader)),
                                               total=num_iters):
        pc1, pc2, sf, generated_data, path = train_batch  # source data
        (pc1_target, pc2_target, generated_data_target,
         pc1_target_2, pc2_target_2, generated_data_target_2,
         path_target, _) = target_batch  # target data
        oom = False
        try:
            cur_sf = sf.to(device, non_blocking=True)
            cur_pc1_target = pc1_target.to(device, non_blocking=True)
            cur_pc1_target_2 = pc1_target_2.to(device, non_blocking=True)  # augmented -> target_2
            cur_pc2_target = pc2_target.to(device, non_blocking=True)

            # student(source), student(target_2)
            output = student_model(pc1, pc2, generated_data)
            output_target_student = student_model(pc1_target_2, pc2_target, generated_data_target_2)
            # teacher(target); detached so no gradient flows back into the teacher
            output_target_teacher = teacher_model(pc1_target, pc2_target, generated_data_target).detach()

            # supervised loss for source domain
            epe3d_loss = criterion_epe(input=output, target=cur_sf).mean()
            # consistency loss for target domain (teacher on target vs. student on target_2)
            consistency_loss = criterion_consist(cur_pc1_target, output_target_teacher, cur_pc2_target,
                                                 cur_pc1_target_2 + output_target_student)
            # total loss
            loss = epe3d_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # teacher <- EMA of the freshly updated student weights
            teacher_model.zero_grad()
            teacher_optimizer.step()

            epe3d_losses.update(epe3d_loss.item(), pc1.size(0))  # batch size can only be 1 for now
            consis_losses.update(consistency_loss.item(), pc1.size(0))
            if i % args.print_freq == 0:
                logger.log('Epoch: [{0}][{1}/{2}]\t'
                           'EPE3D Loss {epe3d_losses_.val:.4f} ({epe3d_losses_.avg:.4f}) Consist {consis_losses_.val:.4f} ({consis_losses_.avg:.4f})'
                           .format(epoch + 1, i + 1, num_iters, epe3d_losses_=epe3d_losses,
                                   consis_losses_=consis_losses), end='')
                logger.log('')
        except RuntimeError as ex:
            logger.log("in TRAIN, RuntimeError " + repr(ex))
            logger.log("batch idx: " + str(i) + ' path: ' + path[0])
            traceback.print_tb(ex.__traceback__, file=logger.out_fd)
            traceback.print_tb(ex.__traceback__)
            if "CUDA error: out of memory" in str(ex) or "cuda runtime error" in str(ex):
                logger.log("out of memory, continue")
                oom = True
            else:
                sys.exit(1)
        if oom:
            # Recover *outside* the except clause. While the exception object is
            # alive, its traceback keeps the failing frames (and their tensors)
            # referenced, so empty_cache() could not return that memory.
            del pc1, pc2, sf, generated_data
            cur_sf = cur_pc1_target = cur_pc1_target_2 = cur_pc2_target = None
            output = output_target_student = output_target_teacher = None
            epe3d_loss = consistency_loss = loss = None
            torch.cuda.empty_cache()
            gc.collect()

    logger.log(' * Train EPE3D {epe3d_losses_.avg:.4f}'.format(epe3d_losses_=epe3d_losses))
    return epe3d_losses.avg
def validate(val_loader, model, criterion, logger):
    """Evaluate ``model`` on ``val_loader`` and return the mean EPE3D loss.

    Switches ``model`` to eval mode (the previous mode is not restored) and
    runs under ``torch.no_grad()``. A batch that fails with a CUDA
    out-of-memory RuntimeError is skipped. Any other RuntimeError terminates
    the process with exit code 1. Reads the module-global ``args``
    (``print_freq``).
    """
    epe3d_losses = AverageMeter()
    model.eval()
    with torch.no_grad():
        for i, (pc1, pc2, sf, generated_data, path) in enumerate(val_loader):
            oom = False
            try:
                cur_sf = sf.cuda(non_blocking=True)
                output = model(pc1, pc2, generated_data)
                epe3d_loss = criterion(input=output, target=cur_sf)
                epe3d_losses.update(epe3d_loss.mean().item())
                if i % args.print_freq == 0:
                    logger.log('Test: [{0}/{1}]\t'
                               'EPE3D loss {epe3d_losses_.val:.4f} ({epe3d_losses_.avg:.4f})'
                               .format(i + 1, len(val_loader),
                                       epe3d_losses_=epe3d_losses))
            except RuntimeError as ex:
                logger.log("in VAL, RuntimeError " + repr(ex))
                traceback.print_tb(ex.__traceback__, file=logger.out_fd)
                traceback.print_tb(ex.__traceback__)
                if "CUDA error: out of memory" in str(ex) or "cuda runtime error" in str(ex):
                    logger.log("out of memory, continue")
                    oom = True
                else:
                    sys.exit(1)
            if oom:
                # Recover outside the except clause so the exception's traceback
                # no longer pins the failing frames' tensors. Drop every
                # per-batch reference before emptying the cache.
                del pc1, pc2, sf, generated_data
                cur_sf = output = epe3d_loss = None
                torch.cuda.empty_cache()
                gc.collect()
                print('remained objects after OOM crash')
    logger.log(' * EPE3D loss {epe3d_loss_.avg:.4f}'.format(epe3d_loss_=epe3d_losses))
    return epe3d_losses.avg
if __name__ == '__main__':
    # Report this process's id, then hand control to the training entry point.
    print('pid: ', os.getpid())
    main()