Skip to content

Commit

Permalink
Add pretrained weights on Chairs and Things for raft_large (pytorch#5060
Browse files Browse the repository at this point in the history
)
NicolasHug authored Dec 8, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d416b2c commit 849d02b
Showing 5 changed files with 157 additions and 33 deletions.
57 changes: 57 additions & 0 deletions references/optical_flow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Optical flow reference training scripts

This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.


### RAFT Large

The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
recipe from https://github.com/princeton-vl/RAFT.

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_chairs \
--model raft_large \
--train-dataset chairs \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
--num-steps 100000 \
--output-dir $chairs_dir
```

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model raft_large \
--train-dataset things \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
--num-steps 100000 \
--freeze-batch-norm \
--output-dir $things_dir\
--resume $chairs_dir/$name_chairs.pth
```


### Evaluation

```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
```

This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel. Results may vary slightly depending on the batch size and
the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`:

```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
32 changes: 27 additions & 5 deletions references/optical_flow/train.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,16 @@
from pathlib import Path

import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small

try:
from torchvision.prototype import models as PM
from torchvision.prototype.models import optical_flow as PMOF
except ImportError:
PM = PMOF = None


def get_train_dataset(stage, dataset_root):
@@ -125,6 +131,13 @@ def inner_loop(blob):

def validate(model, args):
val_datasets = args.val_dataset or []

if args.weights:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = OpticalFlowPresetEval()

for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
@@ -134,14 +147,14 @@ def validate(model, args):
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)

val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_validate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_validate(
model,
@@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
def main(args):
utils.setup_ddp(args)

model = raft_small() if args.small else raft_large()
if args.weights:
model = PMOF.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

@@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)

parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

parser.add_argument(
"--num_flow_updates",
11 changes: 9 additions & 2 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
)
def test_schema_meta_validation(model_fn):
classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
"quantization": classification_fields + ["backend", "quantization", "unquantized"],
"segmentation": ["categories", "mIoU", "acc"],
"video": classification_fields,
"optical_flow": [],
}
module_name = model_fn.__module__.split(".")[-2]
fields = set(defaults["all"] + defaults[module_name])
@@ -201,13 +203,18 @@ def test_old_vs_new_factory(model_fn, dev):
if module_name == "detection":
x = [x]

if module_name == "optical_flow":
args = [x, x] # RAFT model requires img1, img2 as input
else:
args = [x]

# compare with new model builder parameterized in the old fashion way
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)


def test_smoke():
28 changes: 21 additions & 7 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops import ConvNormActivation

from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once
from ._utils import grid_sample, make_coords_grid, upsample_flow

@@ -19,6 +20,9 @@
)


_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}


class ResidualBlock(nn.Module):
"""Slightly modified Residual block with extra relu and biases."""

@@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
hidden_state = torch.tanh(hidden_state)
context = F.relu(context)

coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)

flow_predictions = []
for _ in range(num_flow_updates):
@@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):

def _raft(
*,
arch=None,
pretrained=False,
progress=False,
# Feature encoder
feature_encoder_layers,
feature_encoder_block,
@@ -560,14 +567,19 @@ def _raft(
multiplier=0.25, # See comment in MaskPredictor about this
)

return RAFT(
model = RAFT(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_block=corr_block,
update_block=update_block,
mask_predictor=mask_predictor,
**kwargs, # not really needed, all params should be consumed by now
)
if pretrained:
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
model.load_state_dict(state_dict)

return model


def raft_large(*, pretrained=False, progress=True, **kwargs):
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""

if pretrained:
raise ValueError("No checkpoint is available for raft_large")

return _raft(
arch="raft_large",
pretrained=pretrained,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""

if pretrained:
raise ValueError("No checkpoint is available for raft_small")

return _raft(
arch="raft_small",
pretrained=pretrained,
progress=progress,
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
62 changes: 43 additions & 19 deletions torchvision/prototype/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
@@ -4,12 +4,11 @@
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock

# from torchvision.prototype.transforms import RaftEval
from torchvision.prototype.transforms import RaftEval
from torchvision.transforms.functional import InterpolationMode

from .._api import WeightsEnum

# from .._api import Weights
from .._api import Weights
from .._utils import handle_legacy_interface


@@ -22,17 +21,33 @@
)


_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}


class Raft_Large_Weights(WeightsEnum):
pass
# C_T_V1 = Weights(
# # Chairs + Things
# url="",
# transforms=RaftEval,
# meta={
# "recipe": "",
# "epe": -1234,
# },
# )
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
transforms=RaftEval,
meta={
**_COMMON_META,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 1.4411,
"sintel_train_finalpass_epe": 2.7894,
},
)

C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
transforms=RaftEval,
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.3822,
"sintel_train_finalpass_epe": 2.7161,
},
)

# C_T_SKHT_V1 = Weights(
# # Chairs + Things + Sintel fine-tuning, i.e.:
@@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
# },
# )

# default = C_T_V1
default = C_T_V2


class Raft_Small_Weights(WeightsEnum):
@@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
# default = C_T_V1


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
weights(Raft_Large_weights, optinal): TODO not implemented yet
weights(Raft_Large_weights, optional): pretrained weights to use.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *

weights = Raft_Large_Weights.verify(weights)

return _raft(
model = _raft(
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
**kwargs,
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model


@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *

weights = Raft_Small_Weights.verify(weights)

return _raft(
model = _raft(
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
use_mask_predictor=False,
**kwargs,
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model

0 comments on commit 849d02b

Please sign in to comment.