Skip to content

Commit

Permalink
Add training reference for optical flow models (pytorch#5027)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 7, 2021
1 parent 47bd962 commit 4dd8b5c
Show file tree
Hide file tree
Showing 2 changed files with 616 additions and 0 deletions.
334 changes: 334 additions & 0 deletions references/optical_flow/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small


def get_train_dataset(stage, dataset_root):
if stage == "chairs":
transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
elif stage == "things":
transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
elif stage == "sintel_SKH": # S + K + H as from paper
crop_size = (368, 768)
transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
elif stage == "kitti":
transforms = OpticalFlowPresetTrain(
# resize and crop params
crop_size=(288, 960),
min_scale=-0.2,
max_scale=0.4,
stretch_prob=0,
# flip params
do_flip=False,
# jitter params
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.3 / 3.14,
asymmetric_jitter_prob=0,
)
return KittiFlow(root=dataset_root, split="train", transforms=transforms)
else:
raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size

model.eval()

sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True,
num_workers=args.num_workers,
)

num_flow_updates = num_flow_updates or args.num_flow_updates

def inner_loop(blob):
if blob[0].dim() == 3:
# input is not batched so we add an extra dim for consistency
blob = [x[None, :, :, :] if x is not None else None for x in blob]

image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]

image1, image2 = image1.cuda(), image2.cuda()

padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)

flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
flow_pred = flow_predictions[-1]
flow_pred = padder.unpad(flow_pred).cpu()

metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

# We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
# per-pixel epe: average epe of all pixels of all images
# per-image epe: average epe on each image independently, then average over images
for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper
logger.meters[name].update(metrics[name], n=num_pixels_tot)
logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

logger = utils.MetricLogger()
for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
logger.add_meter(meter_name, fmt="{global_avg:.4f}")

num_processed_samples = 0
for blob in logger.log_every(val_loader, header=header, print_freq=None):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size

num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)

if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])

logger.synchronize_between_processes()
print(header, logger)


def validate(model, args):
val_datasets = args.val_dataset or []
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
# see comment in InputPadder
if args.batch_size != 1 and args.rank == 0:
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)

val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
_validate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
)
_validate(
model,
args,
val_dataset,
num_flow_updates=32,
padder_mode="sintel",
header=f"Sintel val {pass_name}",
)
else:
warnings.warn(f"Can't validate on {val_dataset}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
for data_blob in logger.log_every(train_loader):

optimizer.zero_grad()

image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

metrics.pop("f1")
logger.update(loss=loss, **metrics)

loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

optimizer.step()
scheduler.step()

current_step += 1

if current_step == args.num_steps:
return True, current_step

return False, current_step


def main(args):
utils.setup_ddp(args)

model = raft_small() if args.small else raft_large()
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

if args.resume is not None:
d = torch.load(args.resume, map_location="cpu")
model.load_state_dict(d, strict=True)

if args.train_dataset is None:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
validate(model, args)
return

print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.backends.cudnn.benchmark = True

model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)

train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.num_workers,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
total_steps=args.num_steps + 100,
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)

logger = utils.MetricLogger()

done = False
current_epoch = current_step = 0
while not done:
print(f"EPOCH {current_epoch}")

sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs
done, current_step = train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
current_step=current_step,
args=args,
)

# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger)

current_epoch += 1

if args.rank == 0:
# TODO: Also save the optimizer and scheduler
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")

if current_epoch % args.val_freq == 0 or done:
validate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)


def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
parser.add_argument(
"--name",
default="raft",
type=str,
help="The name of the experiment - determines the name of the files where weights are saved.",
)
parser.add_argument(
"--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
)
parser.add_argument(
"--resume",
type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
)

parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

parser.add_argument(
"--train-dataset",
type=str,
help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
# TODO: eventually, it might be preferable to support epochs instead of num_steps.
# Keeping it this way for now to reproduce results more easily.
parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
parser.add_argument("--batch-size", type=int, default=6)

parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

parser.add_argument(
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)

parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

parser.add_argument(
"--num_flow_updates",
type=int,
default=12,
help="number of updates (or 'iters') in the update operator of the model.",
)

parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

parser.add_argument(
"--dataset-root",
help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
required=True,
)

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
main(args)
Loading

0 comments on commit 4dd8b5c

Please sign in to comment.