Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more variant and pretrained_weight capability (#130) #154

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 101 additions & 37 deletions test/models/test_omnivore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,110 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import pytest
import torch
import torchmultimodal.models.omnivore as omnivore

from test.test_utils import set_rng_seed
from test.test_utils import assert_expected, set_rng_seed
from torchmultimodal.utils.common import get_current_device


class TestOmnivoreModel(unittest.TestCase):
def setUp(self):
set_rng_seed(42)
self.device = get_current_device()

def test_omnivore_swin_t_forward(self):
model = omnivore.omnivore_swin_t().to(self.device)
self.assertTrue(isinstance(model, torch.nn.Module))

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
image_score = model(image, input_type="image")
self.assertEqual(image_score.size(), torch.Size((1, 1000)))
self.assertAlmostEqual(image_score.abs().sum().item(), 200.27572, 2)

rgbd = torch.randn(1, 4, 1, 112, 112)
rgbd_score = model(rgbd, input_type="rgbd")
self.assertEqual(rgbd_score.size(), torch.Size((1, 19)))
self.assertAlmostEqual(rgbd_score.abs().sum().item(), 3.10466, 3)

video = torch.randn(1, 3, 4, 112, 112)
video_score = model(video, input_type="video")
self.assertEqual(video_score.size(), torch.Size((1, 400)))
self.assertAlmostEqual(video_score.abs().sum().item(), 97.57287, 2)

def test_omnivore_forward_wrong_input_type(self):
model = omnivore.omnivore_swin_t().to(self.device)

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
with self.assertRaises(AssertionError) as cm:
_ = model(image, input_type="_WRONG_TYPE_")
self.assertEqual(
"Unsupported input_type: _WRONG_TYPE_, please use one of {'video', 'rgbd', 'image'}",
str(cm.exception),
)
@pytest.fixture(autouse=True)
def device():
set_rng_seed(42)
return get_current_device()


@pytest.fixture(autouse=True)
def omnivore_swin_t_model(device):
return omnivore.omnivore_swin_t().to(device)


@pytest.fixture(autouse=True)
def omnivore_swin_s_model(device):
return omnivore.omnivore_swin_s().to(device)


@pytest.fixture(autouse=True)
def omnivore_swin_b_model(device):
return omnivore.omnivore_swin_b().to(device)


def test_omnivore_swin_t_forward(omnivore_swin_t_model):
model = omnivore_swin_t_model

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
image_score = model(image, input_type="image")

assert_expected(image_score.size(), torch.Size((1, 1000)))
assert_expected(
image_score.abs().sum(), torch.tensor(194.83563), rtol=1e-3, atol=1e-3
)

rgbd = torch.randn(1, 4, 1, 112, 112)
rgbd_score = model(rgbd, input_type="rgbd")
assert_expected(rgbd_score.size(), torch.Size((1, 19)))
assert_expected(rgbd_score.abs().sum(), torch.tensor(3.18015), rtol=1e-3, atol=1e-3)

video = torch.randn(1, 3, 4, 112, 112)
video_score = model(video, input_type="video")
assert_expected(video_score.size(), torch.Size((1, 400)))
assert_expected(
video_score.abs().sum(), torch.tensor(100.87259), rtol=1e-3, atol=1e-3
)


def test_omnivore_swin_s_forward(omnivore_swin_s_model):
model = omnivore_swin_s_model

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
image_score = model(image, input_type="image")

assert_expected(image_score.size(), torch.Size((1, 1000)))
assert_expected(
image_score.abs().sum(), torch.tensor(240.41123), rtol=1e-3, atol=1e-3
)

rgbd = torch.randn(1, 4, 1, 112, 112)
rgbd_score = model(rgbd, input_type="rgbd")
assert_expected(rgbd_score.size(), torch.Size((1, 19)))
assert_expected(rgbd_score.abs().sum(), torch.tensor(5.73624), rtol=1e-3, atol=1e-3)

video = torch.randn(1, 3, 4, 112, 112)
video_score = model(video, input_type="video")
assert_expected(video_score.size(), torch.Size((1, 400)))
assert_expected(
video_score.abs().sum(), torch.tensor(100.75939), rtol=1e-3, atol=1e-3
)


def test_omnivore_swin_b_forward(omnivore_swin_b_model):
model = omnivore_swin_b_model

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
image_score = model(image, input_type="image")

assert_expected(image_score.size(), torch.Size((1, 1000)))
assert_expected(
image_score.abs().sum(), torch.tensor(293.43484), rtol=1e-3, atol=1e-3
)

rgbd = torch.randn(1, 4, 1, 112, 112)
rgbd_score = model(rgbd, input_type="rgbd")
assert_expected(rgbd_score.size(), torch.Size((1, 19)))
assert_expected(rgbd_score.abs().sum(), torch.tensor(6.76342), rtol=1e-3, atol=1e-3)

video = torch.randn(1, 3, 4, 112, 112)
video_score = model(video, input_type="video")
assert_expected(video_score.size(), torch.Size((1, 400)))
assert_expected(
video_score.abs().sum(), torch.tensor(131.65342), rtol=1e-3, atol=1e-3
)


def test_omnivore_forward_wrong_input_type(omnivore_swin_t_model):
model = omnivore_swin_t_model

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
with pytest.raises(AssertionError, match="Unsupported input_type: _WRONG_TYPE_.+"):
_ = model(image, input_type="_WRONG_TYPE_")
6 changes: 2 additions & 4 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import torch
import torch.distributed as dist
from torch import Tensor


def gpu_test(gpu_count: int = 1):
Expand Down Expand Up @@ -71,9 +71,7 @@ def get_asset_path(file_name: str) -> str:
return str(_ASSET_DIR.joinpath(file_name))


def assert_expected(
actual: Tensor, expected: Tensor, rtol: float = None, atol: float = None
):
def assert_expected(actual: Any, expected: Any, rtol: float = None, atol: float = None):
torch.testing.assert_close(
actual,
expected,
Expand Down
1 change: 1 addition & 0 deletions test/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

import pytest

Expand Down
158 changes: 136 additions & 22 deletions torchmultimodal/models/omnivore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,25 @@
# LICENSE file in the root directory of this source tree.


from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional

import torch
from torch import nn, Tensor
import torchmultimodal.utils.common as common_utils
from torch import nn
from torchmultimodal.modules.encoders.swin_transformer_3d_encoder import (
PatchEmbed3d,
SwinTransformer3d,
)

_OMNIVORE_PRETRAINED_URLS = {
"swin_t_encoder": "https://download.pytorch.org/models/omnivore_swin_t_encoder-b7e39400.pth",
"swin_s_encoder": "https://download.pytorch.org/models/omnivore_swin_s_encoder-40b05ba1.pth",
"swin_b_encoder": "https://download.pytorch.org/models/omnivore_swin_b_encoder-a9134768.pth",
"swin_t_heads": "https://download.pytorch.org/models/omnivore_swin_t_heads-c8bfb7fd.pth",
"swin_s_heads": "https://download.pytorch.org/models/omnivore_swin_s_heads-c5e77246.pth",
"swin_b_heads": "https://download.pytorch.org/models/omnivore_swin_b_heads-3c38b3ed.pth",
}


def _imagenet1k_head(input_dim: int) -> nn.Module:
return nn.Linear(input_dim, 1000, bias=True)
Expand Down Expand Up @@ -46,22 +56,24 @@ class Omnivore(nn.Module):
Omnivore (https://arxiv.org/abs/2201.08377) is a single model that able to do classification
on images, videos, and single-view 3D data using the same shared parameters of the encoder.

Args: encoder (nn.Module): Instantiated encoder.
See SwinTransformer3dEncoder class.
heads (Optinal[nn.ModuleDict]): Dictionary of multiple heads for each dataset type
Args:
encoder (nn.Module): Instantiated encoder. It generally accept a video backbone.
The paper use SwinTransformer3d for the encoder.
heads (Optional[nn.ModuleDict]): Dictionary of multiple heads for each dataset type

Inputs: x (Tensor): 5 Dimensional batched video tensor with format of B C D H W
where B is batch, C is channel, D is time, H is height, and W is width.
input_type (str): The dataset type of the input, this will used to choose
the correct head.
Inputs:
x (Tensor): 5 Dimensional batched video tensor with format of B C D H W
where B is batch, C is channel, D is time, H is height, and W is width.
input_type (str): The dataset type of the input, this will used to choose
the correct head.
"""

def __init__(self, encoder: nn.Module, heads: nn.ModuleDict):
super().__init__()
self.encoder = encoder
self.heads = heads

def forward(self, x: torch.Tensor, input_type: str) -> Tensor:
def forward(self, x: torch.Tensor, input_type: str) -> torch.Tensor:
x = self.encoder(x)
assert (
input_type in self.heads
Expand All @@ -78,9 +90,9 @@ class PatchEmbedOmnivore(nn.Module):
reference: https://arxiv.org/abs/2201.08377

Args:
patch_size (Tuple[int, int, int]): Patch token size. Default: (2, 4, 4)
embed_dim (int): Number of linear projection output channels. Default: 96
norm_layer (nn.Module, optional): Normalization layer. Default: None
patch_size (Tuple[int, int, int]): Patch token size. Default: ``(2, 4, 4)``
embed_dim (int): Number of linear projection output channels. Default: ``96``
norm_layer (nn.Module, optional): Normalization layer. Default: ``None``
"""

def __init__(
Expand Down Expand Up @@ -118,7 +130,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def _omnivore_swin_t_encoder() -> SwinTransformer3d:
def omnivore_swin_t_encoder(
pretrained: bool = False, progress: bool = True
) -> SwinTransformer3d:
encoder = SwinTransformer3d(
patch_size=[2, 4, 4],
embed_dim=96,
Expand All @@ -130,16 +144,116 @@ def _omnivore_swin_t_encoder() -> SwinTransformer3d:
patch_embed=PatchEmbedOmnivore,
num_classes=None,
)
if pretrained:
common_utils.load_module_from_url(
encoder,
_OMNIVORE_PRETRAINED_URLS["swin_t_encoder"],
progress=progress,
)
return encoder


# TODO: add pretrained weight capability
def omnivore_swin_t(
encoder_only: bool = False,
) -> Union[Omnivore, SwinTransformer3d]:
encoder = _omnivore_swin_t_encoder()
if encoder_only:
return encoder
def omnivore_swin_s_encoder(
pretrained: bool = False, progress: bool = True
) -> SwinTransformer3d:
encoder = SwinTransformer3d(
patch_size=[2, 4, 4],
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 7, 7],
stochastic_depth_prob=0.3,
norm_layer=nn.LayerNorm,
patch_embed=PatchEmbedOmnivore,
num_classes=None,
)
if pretrained:
common_utils.load_module_from_url(
encoder,
_OMNIVORE_PRETRAINED_URLS["swin_s_encoder"],
progress=progress,
)
return encoder


def omnivore_swin_b_encoder(
pretrained: bool = False, progress: bool = True
) -> SwinTransformer3d:
encoder = SwinTransformer3d(
patch_size=[2, 4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[16, 7, 7],
stochastic_depth_prob=0.3,
norm_layer=nn.LayerNorm,
patch_embed=PatchEmbedOmnivore,
num_classes=None,
)
if pretrained:
common_utils.load_module_from_url(
encoder,
_OMNIVORE_PRETRAINED_URLS["swin_b_encoder"],
progress=progress,
)
return encoder


def omnivore_swin_t(pretrained: bool = False, progress: bool = True) -> nn.Module:
"""
Builder function to get omnivore model with swin_t variant encoder
Args:
pretrained (bool): If true then the it will load pretrained weight,
otherwise it will have random weight (default: ``False``)
progress (bool): If true then there will be a progress bar for downloading weight (default: ``True``)
"""
encoder = omnivore_swin_t_encoder(pretrained=pretrained)
heads = _multimodal_head(input_dim=encoder.num_features)
return Omnivore(encoder, heads)
if pretrained:
common_utils.load_module_from_url(
heads,
_OMNIVORE_PRETRAINED_URLS["swin_t_heads"],
progress=progress,
)
model = Omnivore(encoder, heads)
return model


def omnivore_swin_s(pretrained: bool = False, progress: bool = True) -> nn.Module:
"""
Builder function to get omnivore model with swin_s variant encoder
Args:
pretrained (bool): If true then the it will load pretrained weight,
otherwise it will have random weight (default: ``False``)
progress (bool): If true then there will be a progress bar for downloading weight (default: ``True``)
"""
encoder = omnivore_swin_s_encoder(pretrained=pretrained)
heads = _multimodal_head(input_dim=encoder.num_features)
if pretrained:
common_utils.load_module_from_url(
heads,
_OMNIVORE_PRETRAINED_URLS["swin_s_heads"],
progress=progress,
)
model = Omnivore(encoder, heads)
return model


def omnivore_swin_b(pretrained: bool = False, progress: bool = True) -> nn.Module:
"""
Builder function to get omnivore model with swin_b variant encoder
Args:
pretrained (bool): If true then the it will load pretrained weight,
otherwise it will have random weight (default: ``False``)
progress (bool): If true then there will be a progress bar for downloading weight (default: ``True``)
"""
encoder = omnivore_swin_b_encoder(pretrained=pretrained)
heads = _multimodal_head(input_dim=encoder.num_features)
if pretrained:
common_utils.load_module_from_url(
heads,
_OMNIVORE_PRETRAINED_URLS["swin_b_heads"],
progress=progress,
)
model = Omnivore(encoder, heads)
return model
Loading