Skip to content

Commit

Permalink
Add pretrained weights for raft_small from original paper (pytorch#5070)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 9, 2021
1 parent 4cacf5a commit 48e2f23
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
8 changes: 5 additions & 3 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
)


_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
_MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
# TODO: change to V2 once we upload our own weights
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
}


class ResidualBlock(nn.Module):
Expand Down Expand Up @@ -641,8 +645,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""
if pretrained:
raise ValueError("No checkpoint is available for raft_small")

return _raft(
arch="raft_small",
Expand Down
26 changes: 15 additions & 11 deletions torchvision/prototype/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,19 @@ class Raft_Large_Weights(WeightsEnum):


class Raft_Small_Weights(WeightsEnum):
pass
# C_T_V1 = Weights(
# url="", # TODO
# transforms=RaftEval,
# meta={
# "recipe": "",
# "epe": -1234,
# },
# )
# default = C_T_V1
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-small.pth)
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
transforms=RaftEval,
meta={
**_COMMON_META,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 2.1231,
"sintel_train_finalpass_epe": 3.2790,
},
)

default = C_T_V1 # TODO: Change to V2 once we upload our own weights


@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
Expand Down Expand Up @@ -140,7 +143,8 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
return model


@handle_legacy_interface(weights=("pretrained", None))
# TODO: change to V2 once we upload our own weights
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
"""RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Expand Down

0 comments on commit 48e2f23

Please sign in to comment.