Refactor towhee.models.layers (towhee-io#1514)
Signed-off-by: Jael Gu <[email protected]>
jaelgu authored Jul 5, 2022
1 parent 9bee3e2 commit 122e899
Showing 14 changed files with 306 additions and 130 deletions.
10 changes: 7 additions & 3 deletions tests/unittests/models/allinone/test_allinone.py
@@ -21,7 +21,11 @@

class AllinoneTest(unittest.TestCase):
def test_allinone(self):
vcoph = VCOPHeader()
x = torch.rand(16, 3, 768)
vcoph = VCOPHeader(feature_size=4)
x = torch.rand(2, 3, 4)
out = vcoph(x)
self.assertTrue(out.shape == torch.Size([16, 6]))
self.assertTrue(out.shape == torch.Size([2, 6]))


if __name__ == "__main__":
unittest.main()
@@ -16,10 +16,10 @@
import unittest
import torch

from towhee.models.layers.multi_head_attention import MultiHeadAttention
from towhee.models.layers.attention import MultiHeadAttention, Attention


class MHATest(unittest.TestCase):
class AttentionTest(unittest.TestCase):
def test_mha_with_lrp(self):
seq_len = 21
c_dim = 10
@@ -30,3 +30,14 @@ def test_mha_with_lrp(self):
# torch.Size([8, 21, 10])
out2 = mod.relprop(out1, **kwargs)
self.assertTrue(out2.shape == torch.Size([8, 21, 10]))

def test_attention(self):
q = k = v = torch.ones((1, 3))
mod = Attention(scale=1, att_dropout=0.)
scores, context = mod(q, k, v)
self.assertTrue(scores == 1)
self.assertTrue((context == torch.tensor([1, 1, 1])).all())


if __name__ == '__main__':
unittest.main()
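
Note on the refactor: MultiHeadAttention now lives in towhee.models.layers.attention alongside the new Attention layer exercised above. Below is a minimal standalone sketch of the Attention call; the constructor arguments and the (scores, context) return order are taken from the unit test, while the larger 4x8 input is only assumed to follow the same shape rule (scores over query/key pairs, context as the weighted values).

import torch
from towhee.models.layers.attention import Attention

# Scaled dot-product attention from the refactored layers module.
attn = Attention(scale=0.5, att_dropout=0.)
q = torch.rand(4, 8)   # (seq_len, dim)
k = torch.rand(4, 8)
v = torch.rand(4, 8)
scores, context = attn(q, k, v)   # attention weights first, weighted values second
print(scores.shape)    # expected torch.Size([4, 4])
print(context.shape)   # expected torch.Size([4, 8])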
40 changes: 40 additions & 0 deletions tests/unittests/models/layers/test_position_encoding.py
@@ -0,0 +1,40 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
import torch

from towhee.models.layers.position_encoding import build_position_encoding


class TestPositionEncoding(unittest.TestCase):
"""
Test Transformer Encoder
"""
x = torch.rand(1, 2)

def test_sine(self):
pos_embed = build_position_encoding(hidden_dim=2*2, max_len=4, position_embedding='sine')
out = pos_embed(self.x)
self.assertTrue(out.shape == (1, 1, 2))

def test_learned(self):
pos_embed = build_position_encoding(hidden_dim=2*2, position_embedding='learned')
out = pos_embed(self.x)
self.assertTrue(out.shape == (1, 4, 1, 2))


if __name__ == '__main__':
unittest.main()
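
For reference, a standalone sketch of the build_position_encoding factory covered by this new test. The two position_embedding modes and the printed shapes follow the assertions above; the shapes should be read as implementation-specific behavior rather than a documented contract.

import torch
from towhee.models.layers.position_encoding import build_position_encoding

x = torch.rand(1, 2)

# Fixed sinusoidal encoding.
sine_embed = build_position_encoding(hidden_dim=2 * 2, max_len=4, position_embedding='sine')
print(sine_embed(x).shape)     # torch.Size([1, 1, 2]) per the test above

# Learned embedding table.
learned_embed = build_position_encoding(hidden_dim=2 * 2, position_embedding='learned')
print(learned_embed(x).shape)  # torch.Size([1, 4, 1, 2]) per the test above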
34 changes: 34 additions & 0 deletions tests/unittests/models/layers/test_transformer_encoder.py
@@ -0,0 +1,34 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
import torch

from towhee.models.layers.transformer_encoder import TransformerEncoder


class TransformerEncoderTest(unittest.TestCase):
"""
Test Transformer Encoder
"""
def test_transformer_encoder(self):
dummy_x = torch.rand(2, 4)
mode = TransformerEncoder(d_model=4, n_head=1, dim_ff=1, dropout=0.0, num_layers=1, num_frames=2)
out = mode(dummy_x)
self.assertTrue(out.shape == (2, 2, 4))


if __name__ == '__main__':
unittest.main()
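
Likewise, a self-contained usage sketch of the new TransformerEncoder layer. It mirrors the configuration in the test, since this diff does not document the shape semantics beyond the single assertion.

import torch
from towhee.models.layers.transformer_encoder import TransformerEncoder

# One encoder layer over 2 frames with d_model=4, as in the unit test.
encoder = TransformerEncoder(d_model=4, n_head=1, dim_ff=1, dropout=0.0,
                             num_layers=1, num_frames=2)
features = torch.rand(2, 4)
out = encoder(features)
print(out.shape)  # torch.Size([2, 2, 4]) per the test above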
41 changes: 24 additions & 17 deletions tests/unittests/models/omnivore/test_omnivore.py
@@ -14,32 +14,39 @@

import torch
import unittest
from towhee.models.omnivore.omnivore import omnivore_swins, omnivore_swint, omnivore_swinb_imagenet21k, omnivore_swinl_imagenet21k
from towhee.models.omnivore.omnivore import omnivore_swins, omnivore_swint, omnivore_swinb_imagenet21k, \
omnivore_swinl_imagenet21k


class OmnivoreTest(unittest.TestCase):
def test_omnivore_swins(self):
pretrained = False
model = omnivore_swins(pretrained = pretrained)
x = torch.randn(10,3,5,4,4)
y = model(x,"video")
self.assertTrue(y.shape == torch.Size([10,400]))
model = omnivore_swins(pretrained=pretrained)
x = torch.randn(1, 3, 5, 4, 4)
y = model(x, "video")
self.assertTrue(y.shape == torch.Size([1, 400]))

def test_omnivore_swint(self):
pretrained = False
model = omnivore_swint(pretrained = pretrained)
x = torch.randn(10,3,5,4,4)
y = model(x,"video")
self.assertTrue(y.shape == torch.Size([10,400]))
model = omnivore_swint(pretrained=pretrained)
x = torch.randn(1, 3, 5, 4, 4)
y = model(x, "video")
self.assertTrue(y.shape == torch.Size([1, 400]))

def test_omnivore_swinb_imagenet21k(self):
pretrained = False
model = omnivore_swinb_imagenet21k(pretrained = pretrained)
x = torch.randn(10,3,5,4,4)
y = model(x,"video")
self.assertTrue(y.shape == torch.Size([10,400]))
model = omnivore_swinb_imagenet21k(pretrained=pretrained)
x = torch.randn(1, 3, 5, 4, 4)
y = model(x, "video")
self.assertTrue(y.shape == torch.Size([1, 400]))

def test_omnivore_swinl_imagenet21k(self):
pretrained = False
model = omnivore_swinl_imagenet21k(pretrained = pretrained)
x = torch.randn(10,3,5,4,4)
y = model(x,"video")
self.assertTrue(y.shape == torch.Size([10,400]))
model = omnivore_swinl_imagenet21k(pretrained=pretrained)
x = torch.randn(1, 3, 5, 4, 4)
y = model(x, "video")
self.assertTrue(y.shape == torch.Size([1, 400]))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/unittests/models/violet/test_violet.py
@@ -26,7 +26,7 @@ class VioletTest(unittest.TestCase):
"""
def test_violet(self):
img = torch.rand(1, 5, 3, 32, 32)
txt = torch.randint(20230, size=(1, 5,))
txt = torch.randint(10, size=(1, 5,))
mask_i = [[1, 1, 1, 0, 0]]
mask = []
for i in range(0, 1):
12 changes: 7 additions & 5 deletions towhee/models/coformer/backbone.py
@@ -25,6 +25,7 @@
from towhee.models.coformer.utils import NestedTensor, is_main_process
from towhee.models.layers.position_encoding import build_position_encoding


class BackboneBase(nn.Module):
"""
Args:
@@ -111,14 +112,15 @@ def forward(self, tensor_list: NestedTensor):

return out, pos


def build_backbone(
hidden_dim = 512,
position_embedding = "learned",
backbone = "resnet50",
hidden_dim=512,
position_embedding="learned",
backbone="resnet50",
):
position_embedding = build_position_encoding(
hidden_dim = hidden_dim,
position_embedding = "learned",
hidden_dim=hidden_dim,
position_embedding=position_embedding,
)
train_backbone = False
return_interm_layers = False
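
Beyond the keyword-argument cleanup, this hunk fixes a real bug: build_backbone previously hard-coded position_embedding="learned" when calling build_position_encoding, so the caller's choice was silently ignored. A construction-only sketch of the now-meaningful parameter; the import path follows this file's location, and building the resnet50 backbone may download torchvision weights.

from towhee.models.coformer.backbone import build_backbone

# The position_embedding argument is now forwarded to build_position_encoding
# instead of being overridden with a hard-coded 'learned'; passing another
# supported value (e.g. 'sine') would therefore actually change the encoding.
backbone = build_backbone(hidden_dim=512, position_embedding='learned', backbone='resnet50')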
117 changes: 63 additions & 54 deletions towhee/models/coformer/coformer.py
@@ -24,8 +24,10 @@
from towhee.models.coformer.transformer import build_transformer
from towhee.models.coformer.config import _C


class CoFormer(nn.Module):
"""CoFormer model for Grounded Situation Recognition"""

def __init__(self, backbone, transformer, num_noun_classes, vidx_ridx):
""" Initialize the model.
Parameters:
@@ -57,27 +59,26 @@ def __init__(self, backbone, transformer, num_noun_classes, vidx_ridx):
# classifiers & predictors (for grounded noun prediction)
self.noun_1_classifier = nn.Linear(hidden_dim, self.num_noun_classes)
self.noun_2_classifier = nn.Linear(hidden_dim, self.num_noun_classes)
self.noun_3_classifier = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim*2, self.num_noun_classes))
self.bbox_predictor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim*2, hidden_dim*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim*2, 4))
self.bbox_conf_predictor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim*2, 1))
self.noun_3_classifier = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim * 2, self.num_noun_classes))
self.bbox_predictor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim * 2, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim * 2, 4))
self.bbox_conf_predictor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim * 2, 1))

# layer norms
self.ln1 = nn.LayerNorm(hidden_dim)
self.ln2 = nn.LayerNorm(hidden_dim)


def forward(self, samples, targets=None, inference=False):
"""
Parameters:
@@ -100,32 +101,34 @@ def forward(self, samples, targets=None, inference=False):
# model prediction
for i in range(batch_size):
if not inference:
outs = self.transformer(self.input_proj(src[i:i+1]),
mask[i:i+1], self.il_token_embed.weight, self.rl_token_embed.weight,
outs = self.transformer(self.input_proj(src[i:i + 1]),
mask[i:i + 1], self.il_token_embed.weight, self.rl_token_embed.weight,
self.verb_token_embed.weight, self.role_token_embed.weight,
pos[-1][i:i+1], self.vidx_ridx, targets=targets[i], inference=inference)
pos[-1][i:i + 1], self.vidx_ridx, targets=targets[i], inference=inference)
else:
outs = self.transformer(self.input_proj(src[i:i+1]),
mask[i:i+1], self.il_token_embed.weight, self.rl_token_embed.weight,
self.verb_token_embed.weight, self.role_token_embed.weight,
pos[-1][i:i+1], self.vidx_ridx, inference=inference)
outs = self.transformer(self.input_proj(src[i:i + 1]),
mask[i:i + 1], self.il_token_embed.weight, self.rl_token_embed.weight,
self.verb_token_embed.weight, self.role_token_embed.weight,
pos[-1][i:i + 1], self.vidx_ridx, inference=inference)

# output features & predictions
verb_pred, extracted_rhs, aggregated_rhs, final_rhs, selected_roles = outs[0], outs[1], outs[2], outs[3], outs[4]
verb_pred, extracted_rhs, aggregated_rhs, final_rhs, selected_roles = outs[0], outs[1], outs[2], outs[3], \
outs[4]
num_selected_roles = len(selected_roles)
## auxiliary classifiers
# auxiliary classifiers
if not inference:
extracted_rhs = self.ln1(extracted_rhs[:, :, selected_roles, :])
noun_1_pred = self.noun_1_classifier(extracted_rhs)
noun_1_pred = F.pad(noun_1_pred,
(0,0,0,max_num_roles-num_selected_roles),
(0, 0, 0, max_num_roles - num_selected_roles),
mode='constant',
value=0,
)[-1].view(1,max_num_roles,self.num_noun_classes)
aggregated_rhs = self.ln2(aggregated_rhs[selected_roles].permute(1,0,2).view(1, 1, num_selected_roles, -1))
)[-1].view(1, max_num_roles, self.num_noun_classes)
aggregated_rhs = self.ln2(
aggregated_rhs[selected_roles].permute(1, 0, 2).view(1, 1, num_selected_roles, -1))
noun_2_pred = self.noun_2_classifier(aggregated_rhs)
noun_2_pred = F.pad(noun_2_pred,
(0,0,0,max_num_roles-num_selected_roles),
(0, 0, 0, max_num_roles - num_selected_roles),
mode='constant',
value=0,
)[-1].view(1, max_num_roles, self.num_noun_classes)
@@ -134,14 +137,19 @@
noun_2_pred = None
noun_3_pred = self.noun_3_classifier(final_rhs)
noun_3_pred = F.pad(noun_3_pred,
(0,0,0,max_num_roles-num_selected_roles),
(0, 0, 0, max_num_roles - num_selected_roles),
mode='constant',
value=0,
)[-1].view(1, max_num_roles, self.num_noun_classes)
bbox_pred = self.bbox_predictor(final_rhs).sigmoid()
bbox_pred = F.pad(bbox_pred, (0,0,0,max_num_roles-num_selected_roles), mode='constant', value=0)[-1].view(1, max_num_roles, 4)
bbox_pred = F.pad(bbox_pred, (0, 0, 0, max_num_roles - num_selected_roles), mode='constant', value=0)[
-1].view(1, max_num_roles, 4)
bbox_conf_pred = self.bbox_conf_predictor(final_rhs)
bbox_conf_pred = F.pad(bbox_conf_pred, (0,0,0,max_num_roles-num_selected_roles), mode='constant', value=0)[-1].view(1, max_num_roles, 1)
bbox_conf_pred = \
F.pad(bbox_conf_pred, (0, 0, 0, max_num_roles - num_selected_roles), mode='constant', value=0)[-1].view(
1,
max_num_roles,
1)

batch_verb.append(verb_pred)
batch_noun_1.append(noun_1_pred)
@@ -162,37 +170,38 @@

return out


def create_model(
model_name: str = None,
vidx_ridx = None,
device = None,
):
model_name: str = None,
vidx_ridx=None,
device=None,
):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if model_name == 'coformer':
model_config = _C.MODEL.CoFormer
else:
raise AttributeError(f'Invalid model_name {model_name}.')
backbone = build_backbone(
hidden_dim = model_config.hidden_dim,
position_embedding = model_config.position_embedding,
backbone = model_config.backbone,
)
hidden_dim=model_config.hidden_dim,
position_embedding=model_config.position_embedding,
backbone=model_config.backbone,
)
transformer = build_transformer(
d_model = model_config.hidden_dim,
dropout = model_config.dropout,
nhead = model_config.nhead,
num_glance_enc_layers = model_config.num_glance_enc_layers,
num_gaze_s1_dec_layers = model_config.num_gaze_s1_dec_layers,
num_gaze_s1_enc_layers = model_config.num_gaze_s1_enc_layers,
num_gaze_s2_dec_layers = model_config.num_gaze_s2_dec_layers,
dim_feedforward = model_config.dim_feedforward,
)
d_model=model_config.hidden_dim,
dropout=model_config.dropout,
nhead=model_config.nhead,
num_glance_enc_layers=model_config.num_glance_enc_layers,
num_gaze_s1_dec_layers=model_config.num_gaze_s1_dec_layers,
num_gaze_s1_enc_layers=model_config.num_gaze_s1_enc_layers,
num_gaze_s2_dec_layers=model_config.num_gaze_s2_dec_layers,
dim_feedforward=model_config.dim_feedforward,
)
model = CoFormer(
backbone,
transformer,
num_noun_classes = model_config.num_noun_classes,
vidx_ridx = vidx_ridx,
)
backbone,
transformer,
num_noun_classes=model_config.num_noun_classes,
vidx_ridx=vidx_ridx,
)
model.to(device)
return model
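
For completeness, a construction-only sketch of the reformatted create_model factory. The import path follows this file's location, the vidx_ridx placeholder is purely illustrative (the real verb-index to role-indices mapping comes from the dataset metadata), and loading the resnet50 backbone may download pretrained weights.

import torch
from towhee.models.coformer.coformer import create_model

# Placeholder verb-to-roles mapping; replace with the real vidx_ridx structure.
dummy_vidx_ridx = [[0, 1, 2]]

model = create_model(model_name='coformer', vidx_ridx=dummy_vidx_ridx, device='cpu')
print(isinstance(model, torch.nn.Module))  # True

# Any other model_name raises AttributeError, per the check in create_model above.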