diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md
new file mode 100644
index 00000000000..660ee1b0c38
--- /dev/null
+++ b/references/optical_flow/README.md
@@ -0,0 +1,57 @@
+# Optical flow reference training scripts
+
+This folder contains reference training scripts for optical flow.
+They serve as a log of how to train specific models, and provide baseline
+training and evaluation scripts to quickly bootstrap research.
+
+
+### RAFT Large
+
+The RAFT large model was trained on Flying Chairs and then on Flying Things.
+Both stages used 8 A100 GPUs and a batch size of 2 (so the effective batch size
+is 16). The rest of the hyper-parameters are exactly the same as in the
+original RAFT training recipe from https://github.com/princeton-vl/RAFT.
+
+```
+torchrun --nproc_per_node 8 --nnodes 1 train.py \
+    --dataset-root $dataset_root \
+    --name $name_chairs \
+    --model raft_large \
+    --train-dataset chairs \
+    --batch-size 2 \
+    --lr 0.0004 \
+    --weight-decay 0.0001 \
+    --num-steps 100000 \
+    --output-dir $chairs_dir
+```
+
+```
+torchrun --nproc_per_node 8 --nnodes 1 train.py \
+    --dataset-root $dataset_root \
+    --name $name_things \
+    --model raft_large \
+    --train-dataset things \
+    --batch-size 2 \
+    --lr 0.000125 \
+    --weight-decay 0.0001 \
+    --num-steps 100000 \
+    --freeze-batch-norm \
+    --output-dir $things_dir \
+    --resume $chairs_dir/$name_chairs.pth
+```
+
+
+### Evaluation
+
+```
+torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
+```
+
+This should give an epe (end-point error) of about 1.3822 on the clean pass and
+2.7161 on the final pass of Sintel. Results may vary slightly depending on the
+batch size and the number of GPUs. For the most accurate results use 1 GPU and
+`--batch-size 1`:
+
+```
+Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
+Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
+```
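As a companion to the evaluation section of the README above, the same pretrained checkpoint can be exercised directly from Python. The snippet below is a minimal sketch, not part of the diff; the `[-1, 1]` normalization and the divisible-by-8 frame sizes mirror what the RAFT evaluation presets and the model's input checks expect, and the random tensors merely stand in for two real consecutive frames.

```python
import torch
from torchvision.models.optical_flow import raft_large

# `pretrained=True` now resolves to the C_T_V2 checkpoint registered in this PR.
model = raft_large(pretrained=True, progress=False).eval()

# Two consecutive frames, normalized to [-1, 1], with H and W divisible by 8.
img1 = torch.rand(1, 3, 384, 512) * 2 - 1
img2 = torch.rand(1, 3, 384, 512) * 2 - 1

with torch.no_grad():
    # RAFT returns one flow estimate per refinement iteration;
    # the last element is the final prediction, shaped (1, 2, 384, 512).
    flow_predictions = model(img1, img2, num_flow_updates=12)
final_flow = flow_predictions[-1]
```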
diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index eaf03fbe4f3..326f0be5f66 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -3,10 +3,16 @@
 from pathlib import Path
 
 import torch
+import torchvision.models.optical_flow
 import utils
 from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
 from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
-from torchvision.models.optical_flow import raft_large, raft_small
+
+try:
+    from torchvision.prototype import models as PM
+    from torchvision.prototype.models import optical_flow as PMOF
+except ImportError:
+    PM = PMOF = None
 
 
 def get_train_dataset(stage, dataset_root):
@@ -125,6 +131,13 @@ def inner_loop(blob):
 
 def validate(model, args):
     val_datasets = args.val_dataset or []
+
+    if args.weights:
+        weights = PM.get_weight(args.weights)
+        preprocessing = weights.transforms()
+    else:
+        preprocessing = OpticalFlowPresetEval()
+
     for name in val_datasets:
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
@@ -134,14 +147,14 @@ def validate(model, args):
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )
-            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
+            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
             _validate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
             for pass_name in ("clean", "final"):
                 val_dataset = Sintel(
-                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
+                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
                 _validate(
                     model,
@@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
 
 def main(args):
     utils.setup_ddp(args)
 
-    model = raft_small() if args.small else raft_large()
+    if args.weights:
+        model = PMOF.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
+
     model = model.to(args.local_rank)
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
@@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
         "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
     )
 
-    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
+    parser.add_argument(
+        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
+    )
+    # TODO: resume, pretrained, and weights should be in an exclusive arg group
+    parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
 
     parser.add_argument(
         "--num_flow_updates",
diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index 56e91bb3d48..87a269c7a41 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.quantization)
     + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
 )
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
         "quantization": classification_fields + ["backend", "quantization", "unquantized"],
         "segmentation": ["categories", "mIoU", "acc"],
         "video": classification_fields,
+        "optical_flow": [],
     }
     module_name = model_fn.__module__.split(".")[-2]
     fields = set(defaults["all"] + defaults[module_name])
@@ -201,13 +203,18 @@ def test_old_vs_new_factory(model_fn, dev):
     if module_name == "detection":
         x = [x]
 
+    if module_name == "optical_flow":
+        args = [x, x]  # RAFT model requires img1, img2 as input
+    else:
+        args = [x]
+
     # compare with new model builder parameterized in the old fashion way
     try:
         model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
         model_new = _build_model(model_fn, **kwargs).to(device=dev)
     except ModuleNotFoundError:
         pytest.skip(f"Model '{model_name}' not available in both modules.")
-    torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
+    torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)
 
 
 def test_smoke():
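The `--weights` path in train.py above resolves an enum name string at runtime through the prototype API. As a sketch of what that lookup yields, assuming the prototype `get_weight` helper accepts fully-qualified enum names as the `--weights` flag implies, and using the enum entry this PR registers later in the diff:

```python
from torchvision.prototype import models as PM

# Resolve the CLI string to the Weights entry defined in
# torchvision/prototype/models/optical_flow/raft.py (see below).
weights = PM.get_weight("Raft_Large_Weights.C_T_V2")

# The entry bundles the matching eval transforms and the recorded metadata.
preprocessing = weights.transforms()  # instantiates RaftEval
print(weights.meta["recipe"])         # points at references/optical_flow
```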
diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index f653895598f..ff851b6382e 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -8,6 +8,7 @@
 from torch.nn.modules.instancenorm import InstanceNorm2d
 from torchvision.ops import ConvNormActivation
 
+from ..._internally_replaced_utils import load_state_dict_from_url
 from ...utils import _log_api_usage_once
 from ._utils import grid_sample, make_coords_grid, upsample_flow
 
@@ -19,6 +20,9 @@
 )
 
 
+_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
+
+
 class ResidualBlock(nn.Module):
     """Slightly modified Residual block with extra relu and biases."""
 
@@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         hidden_state = torch.tanh(hidden_state)
         context = F.relu(context)
 
-        coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
-        coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
+        coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
+        coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
 
         flow_predictions = []
         for _ in range(num_flow_updates):
@@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
 
 def _raft(
     *,
+    arch=None,
+    pretrained=False,
+    progress=False,
     # Feature encoder
     feature_encoder_layers,
     feature_encoder_block,
@@ -560,7 +567,7 @@ def _raft(
         multiplier=0.25,  # See comment in MaskPredictor about this
     )
 
-    return RAFT(
+    model = RAFT(
         feature_encoder=feature_encoder,
         context_encoder=context_encoder,
         corr_block=corr_block,
@@ -568,6 +575,11 @@ def _raft(
         mask_predictor=mask_predictor,
         **kwargs,  # not really needed, all params should be consumed by now
     )
+    if pretrained:
+        state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
+        model.load_state_dict(state_dict)
+
+    return model
 
 
 def raft_large(*, pretrained=False, progress=True, **kwargs):
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
         nn.Module: The model.
     """
 
-    if pretrained:
-        raise ValueError("No checkpoint is available for raft_large")
-
     return _raft(
+        arch="raft_large",
+        pretrained=pretrained,
+        progress=progress,
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
         nn.Module: The model.
""" - if pretrained: raise ValueError("No checkpoint is available for raft_small") return _raft( + arch="raft_small", + pretrained=pretrained, + progress=progress, # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), feature_encoder_block=BottleneckBlock, diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 4dad4b3b6b1..4fc7e962864 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -4,12 +4,11 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock - -# from torchvision.prototype.transforms import RaftEval +from torchvision.prototype.transforms import RaftEval +from torchvision.transforms.functional import InterpolationMode from .._api import WeightsEnum - -# from .._api import Weights +from .._api import Weights from .._utils import handle_legacy_interface @@ -22,17 +21,33 @@ ) +_COMMON_META = {"interpolation": InterpolationMode.BILINEAR} + + class Raft_Large_Weights(WeightsEnum): - pass - # C_T_V1 = Weights( - # # Chairs + Things - # url="", - # transforms=RaftEval, - # meta={ - # "recipe": "", - # "epe": -1234, - # }, - # ) + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-things.pth) + url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 1.4411, + "sintel_train_finalpass_epe": 2.7894, + }, + ) + + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.3822, + "sintel_train_finalpass_epe": 2.7161, + }, + ) # C_T_SKHT_V1 = Weights( # # Chairs + Things + Sintel fine-tuning, i.e.: @@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum): # }, # ) - # default = C_T_V1 + default = C_T_V2 class Raft_Small_Weights(WeightsEnum): @@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum): # default = C_T_V1 -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): """RAFT model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - weights(Raft_Large_weights, optinal): TODO not implemented yet + weights(Raft_Large_weights, optional): pretrained weights to use. progress (bool): If True, displays a progress bar of the download to stderr kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class to override any default. 
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
 
     weights = Raft_Large_Weights.verify(weights)
 
-    return _raft(
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
         **kwargs,
     )
 
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
 
 @handle_legacy_interface(weights=("pretrained", None))
 def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
 
     weights = Raft_Small_Weights.verify(weights)
 
-    return _raft(
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(32, 32, 64, 96, 128),
         feature_encoder_block=BottleneckBlock,
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
         use_mask_predictor=False,
         **kwargs,
     )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+    return model
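Finally, with `default = C_T_V2` and the updated `@handle_legacy_interface` mapping, the prototype builder can be driven in several equivalent ways. A sketch of the intended usage, assuming the builder and the weights enum are re-exported from the prototype `optical_flow` namespace as the train.py changes suggest:

```python
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

# These should all load the same C_T_V2 checkpoint:
m1 = raft_large(weights=Raft_Large_Weights.C_T_V2)
m2 = raft_large(weights=Raft_Large_Weights.default)  # alias set in this PR
m3 = raft_large(pretrained=True)  # remapped by @handle_legacy_interface

# The enum also carries the paired eval transforms and the reported metrics:
preprocessing = Raft_Large_Weights.C_T_V2.transforms()
print(Raft_Large_Weights.C_T_V2.meta["sintel_train_finalpass_epe"])  # 2.7161
```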