Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lvwj19 authored Jul 23, 2021
1 parent ed6f3b6 commit ede3017
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions tools/IPABunny_msg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,19 @@
#-----------------------END-----------------------

def train_one_epoch(loader):
logger.log_string('----------------TRAIN STATUS---------------')
logger.log_string('--------------------')
net.train() # set model to training mode

logger.reset_state_dict('train trans loss', 'train rot loss', 'train vs loss', 'train dist')
total_batch = 0
total_seen = 0
loss_sum = 0
rot_loss_sum = 0
vs_loss_sum = 0
trans_loss_sum = 0
dist_mean_sum = 0

for batch_idx, batch_samples in enumerate(loader):
total_batch += 1
labels = {
'rot_label':batch_samples['rot_label'].to(device),
'trans_label':batch_samples['trans_label'].to(device),
Expand All @@ -125,29 +132,43 @@ def train_one_epoch(loader):
losses['total'].backward()
optimizer.step()

dist_mean = torch.mean(torch.norm(pred_results[0].view(-1,3)-labels['trans_label'].view(-1,3), dim=1)).item()
total_seen += (BATCH_SIZE*NUM_POINT)
loss_sum += losses['total'].item()
rot_loss_sum += losses['rot_head'].item()
trans_loss_sum += losses['trans_head'].item()
vs_loss_sum += losses['vis_head'].item()

log_state_dict = {'train trans loss':losses['trans_head'].item(), 'train rot loss':losses['rot_head'].item(),
'train vs loss':losses['vis_head'].item(), 'train dist':dist_mean}
logger.update_state_dict(log_state_dict)
dist_mean = torch.mean(torch.norm(pred_results[0].view(-1,3)-labels['trans_label'].view(-1,3), dim=1)).item()
dist_mean_sum += dist_mean

if batch_idx % DISPLAY_BATCH_STEP == 0 and batch_idx!= 0:
print('Current batch/total batch num: %d/%d'%(batch_idx,len(loader)))
logger.print_state_dict(log=False)
tran_loss_cur = trans_loss_sum/(batch_idx+1)
rot_loss_cur = rot_loss_sum/(batch_idx+1)
dist_mean_cur = dist_mean_sum/(batch_idx+1)
vs_loss_cur = vs_loss_sum/(batch_idx+1)
print('trans_loss: %f\trot_loss: %f\tvs_loss: %f\tmean_dist: %f'%(tran_loss_cur,rot_loss_cur, vs_loss_cur,dist_mean_cur))

print('Current batch/total batch num: %d/%d'%(len(loader),len(loader)))
logger.print_state_dict(log=True)
logger.log_string('train translation loss: %f' % (trans_loss_sum / float(total_batch)))
logger.log_string('train rotation loss: %f' % (rot_loss_sum / float(total_batch)))
logger.log_string('train vis loss: %f' % (vs_loss_sum / float(total_batch)))
logger.log_string('train dist: %f' % (dist_mean_sum / float(total_batch)))


def eval_one_epoch(loader):
logger.log_string('----------------EVAL STATUS---------------')
logger.log_string('--------------------')
net.eval() # set model to eval mode

logger.reset_state_dict('eval trans loss', 'eval rot loss', 'eval vs loss', 'eval dist')

total_batch = 0
total_seen = 0
loss_sum = 0
rot_loss_sum = 0
vs_loss_sum = 0
trans_loss_sum = 0
dist_mean_sum = 0

for batch_idx, batch_samples in enumerate(loader):
total_batch += 1
labels = {
'rot_label':batch_samples['rot_label'].to(device),
'trans_label':batch_samples['trans_label'].to(device),
Expand All @@ -162,15 +183,20 @@ def eval_one_epoch(loader):
with torch.no_grad():
pred_results, losses = net(inputs)

total_seen += (BATCH_SIZE*NUM_POINT)
loss_sum += losses['total'].item()
dist_mean = torch.mean(torch.norm(pred_results[0].view(-1,3)-labels['trans_label'].view(-1,3), dim=1)).item()
rot_loss_sum += losses['rot_head'].item()
trans_loss_sum += losses['trans_head'].item()
vs_loss_sum += losses['vis_head'].item()

log_state_dict = {'eval trans loss':losses['trans_head'].item(), 'eval rot loss':losses['rot_head'].item(),
'eval vs loss':losses['vis_head'].item(), 'eval dist':dist_mean}
logger.update_state_dict(log_state_dict)
dist_mean = torch.mean(torch.norm(pred_results[0].view(-1,3)-labels['trans_label'].view(-1,3), dim=1)).item()
dist_mean_sum += dist_mean

logger.print_state_dict(log=True)
return loss_sum
logger.log_string('eval translation loss: %f' % (trans_loss_sum / float(total_batch)))
logger.log_string('eval rotation loss: %f' % (rot_loss_sum / float(total_batch)))
logger.log_string('eval vis loss: %f' % (vs_loss_sum / float(total_batch)))
logger.log_string('eval dist: %f' % (dist_mean_sum / float(total_batch)))
return (loss_sum / float(total_batch))

def train(start_epoch):
global train_dataset
Expand Down Expand Up @@ -204,4 +230,3 @@ def train(start_epoch):

if __name__=='__main__':
train(start_epoch)

0 comments on commit ede3017

Please sign in to comment.