Commit f36787b

[main] adapt for parrots

T.T. Tang committed Dec 17, 2021
1 parent e6e3304 commit f36787b
Showing 1 changed file with 18 additions and 10 deletions.

28 changes: 18 additions & 10 deletions main.py
@@ -14,7 +14,7 @@
 from torch.utils.data.distributed import DistributedSampler
 
 if torch.__version__ == 'parrots':
-    from tensorboardX import SummaryWriter
+    from pavi import SummaryWriter
 else:
     from torch.utils.tensorboard import SummaryWriter
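Note: the only functional change in this hunk is the logging backend. "parrots" is SenseTime's PyTorch-compatible training framework (it identifies itself through torch.__version__), and pavi appears to be its companion experiment-logging service; the swap assumes pavi's SummaryWriter mirrors the TensorBoard add_scalar / log_dir API closely enough that the training loop below needs no further changes.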

@@ -52,11 +52,15 @@ def setup(rank, world_size):


 def get_dataloader(rank, world_size, batch_size, pin_memory=False, num_workers=0):
-    train = Comma2k19SequenceDataset('data/comma2k19_train_non_overlap.txt', 'data/comma2k19/','train', use_memcache=False)
-    val = Comma2k19SequenceDataset('data/comma2k19_val_non_overlap.txt', 'data/comma2k19/','val', use_memcache=False)
+    train = Comma2k19SequenceDataset('data/comma2k19_train_non_overlap.txt', 's3://comma2k19/','train', use_memcache=True)
+    val = Comma2k19SequenceDataset('data/comma2k19_val_non_overlap.txt', 's3://comma2k19/','val', use_memcache=True)
 
-    train_sampler = DistributedSampler(train, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
-    val_sampler = DistributedSampler(val, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
+    if torch.__version__ == 'parrots':
+        dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True)
+    else:
+        dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
+    train_sampler = DistributedSampler(train, **dist_sampler_params)
+    val_sampler = DistributedSampler(val, **dist_sampler_params)
 
     loader_args = dict(num_workers=num_workers, persistent_workers=True if num_workers > 0 else False, prefetch_factor=2, pin_memory=pin_memory)
     train_loader = DataLoader(train, batch_size, sampler=train_sampler, **loader_args)
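Two things change in this hunk: the datasets now read from an S3-style bucket with memcache enabled, and the sampler kwargs are built conditionally because parrots' DistributedSampler apparently does not accept drop_last (PyTorch added that parameter to DistributedSampler in 1.8). One side effect worth noting: the old code built the validation sampler with shuffle=False, but the shared dist_sampler_params now passes shuffle=True to both splits. A minimal sketch that keeps the per-split shuffle flag while still feature-detecting drop_last (make_sampler is a hypothetical helper, not part of the commit):

    import inspect

    from torch.utils.data.distributed import DistributedSampler

    def make_sampler(dataset, rank, world_size, shuffle):
        # Only pass drop_last where this DistributedSampler supports it
        # (PyTorch >= 1.8; parrots apparently predates it).
        params = dict(num_replicas=world_size, rank=rank, shuffle=shuffle)
        if 'drop_last' in inspect.signature(DistributedSampler).parameters:
            params['drop_last'] = True
        return DistributedSampler(dataset, **params)

    # train_sampler = make_sampler(train, rank, world_size, shuffle=True)
    # val_sampler = make_sampler(val, rank, world_size, shuffle=False)  # preserves the old behavior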
@@ -136,18 +140,19 @@ def main(rank, world_size, args):
     loss = MultipleTrajectoryPredictionLoss(args.mtp_alpha, args.M, args.num_pts, distance_type='angle')
 
     num_steps = 0
+    disable_tqdm = True or (rank != 0)
 
-    for epoch in tqdm(range(args.epochs), disable=(rank != 0)):
+    for epoch in tqdm(range(args.epochs), disable=disable_tqdm):
         train_dataloader.sampler.set_epoch(epoch)
 
-        for batch_idx, data in enumerate(tqdm(train_dataloader, leave=False, disable=(rank != 0))):
+        for batch_idx, data in enumerate(tqdm(train_dataloader, leave=False, disable=disable_tqdm)):
             seq_inputs, seq_labels = data['seq_input_img'].cuda(), data['seq_future_poses'].cuda()
             bs = seq_labels.size(0)
             seq_length = seq_labels.size(1)
 
             hidden = torch.zeros((2, bs, 512)).cuda()
             total_loss = 0
-            for t in tqdm(range(seq_length), leave=False, disable=(rank != 0)):
+            for t in tqdm(range(seq_length), leave=False, disable=disable_tqdm):
                 num_steps += 1
                 inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :]
                 pred_cls, pred_trajectory, hidden = model(inputs, hidden)
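One detail to flag: disable_tqdm = True or (rank != 0) always evaluates to True, since `True or x` short-circuits. The progress bars are therefore disabled on every rank, not just non-zero ranks; presumably this is deliberate (tqdm output tends to clutter batch-scheduler logs on a cluster), but the rank != 0 term is dead code as written.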
@@ -156,6 +161,7 @@ def main(rank, world_size, args):
                 total_loss += (cls_loss + args.mtp_alpha * reg_loss.mean()) / model.module.optimize_per_n_step
 
                 if rank == 0:
+                    # TODO: add a customized log function
                     writer.add_scalar('loss/cls', cls_loss, num_steps)
                     writer.add_scalar('loss/reg', reg_loss.mean(), num_steps)
                     writer.add_scalar('loss/reg_x', reg_loss[0], num_steps)
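The new TODO hints at replacing the block of per-component add_scalar calls with a single helper. A hypothetical shape for it (log_scalars is not in the repository; this is just the obvious refactor the comment gestures at):

    def log_scalars(writer, step, **scalars):
        # Write each named scalar under a slash-separated tag,
        # e.g. loss_cls -> loss/cls, loss_reg -> loss/reg.
        for tag, value in scalars.items():
            writer.add_scalar(tag.replace('_', '/', 1), value, step)

    # log_scalars(writer, num_steps, loss_cls=cls_loss, loss_reg=reg_loss.mean())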
@@ -182,13 +188,15 @@ def main(rank, world_size, args):
                 writer.add_scalar('loss/total', total_loss, num_steps)
 
         lr_scheduler.step()
-        if (epoch + 1) % 1 == 0:
+        if (epoch + 1) % 1 == 0: # TODO: Add to args
             if rank == 0:
                 # save model
-                ckpt_path = os.path.join(writer.get_logdir(), 'epoch_%d.pth' % epoch)
+                ckpt_path = os.path.join(writer.log_dir, 'epoch_%d.pth' % epoch)
                 torch.save(model.module.state_dict(), ckpt_path)
                 print('[Epoch %d] checkpoint saved at %s' % (epoch, ckpt_path))
             print('skipping val...')
+
+            dist.all_gather() # TODO
             continue
         for batch_idx, data in enumerate(val_dataloader):
             data = data.cuda()
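Two assumptions behind the remaining changes are worth spelling out. The switch from writer.get_logdir() to writer.log_dir suggests pavi's SummaryWriter exposes the log directory as an attribute rather than a method (torch's TensorBoard writer happens to offer both). And dist.all_gather() is explicitly a TODO stub: torch.distributed.all_gather requires a pre-allocated list of per-rank tensors plus the local input tensor, so the call would raise a TypeError if it ever executed. A hypothetical filled-in version for averaging a scalar metric across ranks might look like this (names such as metric and gathered are illustrative, not from the commit):

    # Sketch only: replaces the bare dist.all_gather() TODO above.
    metric = total_loss.detach().clone()                      # local scalar tensor
    gathered = [torch.zeros_like(metric) for _ in range(world_size)]
    dist.all_gather(gathered, metric)                         # every rank receives all ranks' values
    if rank == 0:
        writer.add_scalar('loss/mean_across_ranks', torch.stack(gathered).mean(), num_steps)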