Skip to content

Commit

Permalink
Add transrac: add more utils & debug video-swin-transformer (towhee-i…
Browse files Browse the repository at this point in the history
…o#1536)

Signed-off-by: Jael Gu <[email protected]>
  • Loading branch information
jaelgu authored Jul 8, 2022
1 parent 3954566 commit 8e3b9b7
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
10 changes: 9 additions & 1 deletion tests/unittests/models/transrac/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest
import torch

from towhee.models.transrac import DenseMap
from towhee.models.transrac import DenseMap, SimilarityMatrix


class TestUtils(unittest.TestCase):
Expand All @@ -29,6 +29,14 @@ def test_dense_map(self):
out = dense_map(dummy_x)
self.assertTrue(out.shape == torch.Size([5]))

def test_similarity_matrix(self):
    """Check the output shape of SimilarityMatrix for single-row random inputs."""
    feature_dim = 3
    # Three independent random tensors of shape (1, feature_dim): query, key, value.
    query, key, value = (torch.rand(1, feature_dim) for _ in range(3))
    similarity_layer = SimilarityMatrix(input_dim=feature_dim)
    result = similarity_layer(query, key, value)
    # Default num_heads is 4, hence the second dimension of the expected shape.
    self.assertTrue(result.shape == (1, 4, 1, 1))


if __name__ == "__main__":
unittest.main()
38 changes: 38 additions & 0 deletions towhee/models/transrac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# limitations under the License.

from torch import nn
from towhee.models.layers.attention import Attention


class DenseMap(nn.Module):
Expand Down Expand Up @@ -51,3 +52,40 @@ def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, out_dim, dropout=0.25)
def forward(self, x):
x = self.layers(x)
return x


class SimilarityMatrix(nn.Module):
    """
    Build similarity matrix for TransRAC.

    Query, key and value each go through their own linear projection, are split into
    `num_heads` heads, and are then passed to the shared `Attention` layer. The first
    element returned by that layer is the result.

    Args:
        num_heads (`int`):
            Number of heads the projected features are split into.
        input_dim (`int`):
            Size of the last dimension of the incoming query/key/value.
        model_dim (`int`):
            Output size of each linear projection; every head receives
            `model_dim // num_heads` features.
    """

    def __init__(self, num_heads=4, input_dim=512, model_dim=512):
        super().__init__()

        self.num_heads = num_heads
        self.model_dim = model_dim
        self.input_size = input_dim

        # Separate projections for query, key and value.
        self.linear_q = nn.Linear(input_dim, model_dim)
        self.linear_k = nn.Linear(input_dim, model_dim)
        self.linear_v = nn.Linear(input_dim, model_dim)

        self.attention = Attention(att_dropout=0.)

    def _split_heads(self, projected, batch_size):
        """Reshape a projected tensor to (batch, heads, -1, model_dim // heads)."""
        per_head = self.model_dim // self.num_heads
        return projected.reshape(batch_size, -1, self.num_heads, per_head).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """
        Project the inputs, split them by head and run the attention layer.

        Args:
            query, key, value:
                Input tensors whose first dimension is the batch and whose last
                dimension is `input_dim`.
            attn_mask:
                Optional mask forwarded unchanged to the attention layer.

        Returns:
            The first output of the attention layer (noted in the original code as
            the similarity matrix of shape [B, H, F, F] -- TODO confirm against
            `Attention`).
        """
        batch_size = query.size(0)

        # Linear projection of each input.
        projections = (
            self.linear_q(query),
            self.linear_k(key),
            self.linear_v(value),
        )
        # Split every projection into heads.
        query, key, value = (self._split_heads(item, batch_size) for item in projections)

        matrix, _ = self.attention(query, key, value, attn_mask)
        return matrix
4 changes: 4 additions & 0 deletions towhee/models/video_swin_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@
import einops
except ModuleNotFoundError:
os.system('pip install einops')

from .video_swin_transformer import *
from .get_configs import *
from .video_swin_transformer_block import *
25 changes: 15 additions & 10 deletions towhee/models/video_swin_transformer/video_swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,22 @@ def create_model(model_name: str = None, pretrained: bool = False,
if pretrained:
if model_name is None:
raise AssertionError("Fail to load pretrained model: no model name is specified.")
if model_name:
model_configs = get_configs.configs(model_name)
model = VideoSwinTransformer(pretrained=model_configs["pretrained"],
num_classes=model_configs["num_classes"],
embed_dim=model_configs["embed_dim"],
depths=model_configs["depths"],
num_heads=model_configs["num_heads"],
patch_size=model_configs["patch_size"],
window_size=model_configs["window_size"],
drop_path_rate=model_configs["drop_path_rate"],
patch_norm=model_configs["patch_norm"],
device=device)
model_configs = dict(pretrained=model_configs["pretrained"],
num_classes=model_configs["num_classes"],
embed_dim=model_configs["embed_dim"],
depths=model_configs["depths"],
num_heads=model_configs["num_heads"],
patch_size=model_configs["patch_size"],
window_size=model_configs["window_size"],
drop_path_rate=model_configs["drop_path_rate"],
patch_norm=model_configs["patch_norm"],
device=device)
if not pretrained:
model_configs["pretrained"] = None
model = VideoSwinTransformer(**model_configs)
else:
model = VideoSwinTransformer(**kwargs)

return model

0 comments on commit 8e3b9b7

Please sign in to comment.