Commit 37a04c3

canonical vn transformer encoder

Wuziyi616 committed Jul 13, 2022
1 parent 018f745 commit 37a04c3
Showing 5 changed files with 109 additions and 8 deletions.
3 changes: 2 additions & 1 deletion multi_part_assembly/models/modules/vnn/__init__.py
@@ -1,2 +1,3 @@
-from .modules import VNLinear, VNBatchNorm, VNLayerNorm, VNReLU, VNLeakyReLU, VNLinearBNLeakyReLU, VNMaxPool, VNInFeature
+from .modules import VNLinear, VNBatchNorm, VNLayerNorm, VNReLU, VNLeakyReLU, \
+    VNLinearBNLeakyReLU, VNMaxPool, VNInFeature, VNEqFeature
from .transformer import VNTransformerEncoderLayer, VNSelfAttention
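
Aside (not part of the commit): after this re-export, the new module is importable next to the existing VN layers, e.g.:

from multi_part_assembly.models.modules.vnn import VNInFeature, VNEqFeature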
35 changes: 35 additions & 0 deletions multi_part_assembly/models/modules/vnn/modules.py
@@ -257,6 +257,41 @@ def forward(self, x):
        return x_in


+class VNEqFeature(VNInFeature):
+    """Map rotation-invariant (VN-IN) features back to the input's rotation."""
+
+    def forward(self, x, x_in):
+        """
+        Args:
+            x: point features of shape [B, C, 3, N, ...]
+            x_in: rotation-invariant features of shape [B, C, 3, N, ...]
+        Returns:
+            rotation-equivariant features: x_in mapped to the rotation of x
+        """
+        if self.dim in [4, 5]:
+            dim = -1 if self.dim == 4 else (-1, -2)
+            x_mean = x.mean(dim=dim, keepdim=True).expand(x.size())
+            x = torch.cat((x, x_mean), dim=1)
+
+        # vn1/vn2/vn_lin predict a per-point 3x3 frame z from the input x
+        z = x
+        z = self.vn1(z)
+        z = self.vn2(z)
+        z = self.vn_lin(z)
+        # unlike VNInFeature, z is applied without the transpose below, which
+        # maps the invariant features x_in back to the rotation of the input x
+        # z = z.transpose(1, 2).contiguous()
+
+        if self.dim == 4:
+            x_eq = torch.einsum('bijm,bjkm->bikm', x_in, z)
+        elif self.dim == 3:
+            x_eq = torch.einsum('bij,bjk->bik', x_in, z)
+        elif self.dim == 5:
+            x_eq = torch.einsum('bijmn,bjkmn->bikmn', x_in, z)
+        else:
+            raise NotImplementedError(f'dim={self.dim} is not supported')
+
+        return x_eq
+
+
""" test code
import torch
from multi_part_assembly.models import VNLayerNorm
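Aside: a minimal sketch (not part of this commit) checking that the VNInFeature / VNEqFeature pair behaves as intended, mirroring the commented test code above. It assumes the [B, C, 3, N] layout (dim=4) and the repo's random_rotation_matrixs helper.

import torch
from multi_part_assembly.models import VNEqFeature, VNInFeature
from multi_part_assembly.utils import random_rotation_matrixs

feats_in = VNInFeature(16, dim=4)  # equivariant -> invariant
feats_eq = VNEqFeature(16, dim=4)  # invariant -> equivariant
x = torch.rand(2, 16, 3, 100)  # [B, C, 3, N]
rmat = random_rotation_matrixs((2, 16))  # [2, 16, 3, 3]

# invariance: rotating the input should not change the VN-IN features
print((feats_in(x) - feats_in(rmat @ x)).abs().max())  # expected ~0

# equivariance: mapping back should commute with the rotation
x_in = feats_in(x)
print((feats_eq(rmat @ x, x_in) - rmat @ feats_eq(x, x_in)).abs().max())  # expected ~0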
2 changes: 1 addition & 1 deletion multi_part_assembly/models/pn_transformer/__init__.py
@@ -1,4 +1,4 @@
-from .transformer import TransformerEncoder, VNTransformerEncoder
+from .transformer import TransformerEncoder, VNTransformerEncoder, CanonicalVNTransformerEncoder
from .network import PNTransformer
from .network_gan import PNTransformerGAN
from .network_refine import PNTransformerRefine
72 changes: 69 additions & 3 deletions multi_part_assembly/models/pn_transformer/transformer.py
@@ -1,7 +1,7 @@
import torch.nn as nn

from multi_part_assembly.models import VNTransformerEncoderLayer, \
-    VNLayerNorm, VNLinear
+    VNLayerNorm, VNLinear, VNInFeature, VNEqFeature


def build_transformer_encoder(
@@ -10,6 +10,7 @@ def build_transformer_encoder(
    ffn_dim,
    num_layers,
    norm_first=True,
+    dropout=0.1,
):
    """Build the Transformer Encoder.
@@ -26,6 +27,7 @@
        d_model=d_model,
        nhead=num_heads,
        dim_feedforward=ffn_dim,
+        dropout=dropout,
        norm_first=norm_first,
        batch_first=True,
    )
@@ -45,6 +47,7 @@ def __init__(
        ffn_dim,
        num_layers,
        norm_first=True,
+        dropout=0.1,
        out_dim=None,
    ):
        super().__init__()
@@ -55,6 +58,7 @@ def __init__(
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            norm_first=norm_first,
+            dropout=dropout,
        )
        self.out_fc = nn.Linear(d_model, out_dim) if \
            out_dim is not None else nn.Identity()
@@ -78,6 +82,70 @@ def forward(self, tokens, valid_masks):
        return self.out_fc(out)


+class CanonicalVNTransformerEncoder(TransformerEncoder):
+    """VN-Transformer encoder with padding-mask support.
+
+    It first maps the equivariant tokens to rotation-invariant features,
+    then applies a standard TransformerEncoder to model token interactions,
+    and finally maps the invariant outputs back to the rotation of the
+    input tokens.
+    """
+
+    def __init__(
+        self,
+        d_model,
+        num_heads,
+        num_layers,
+        dropout=0.,
+        out_dim=None,
+    ):
+        super().__init__(
+            d_model=d_model * 3,
+            num_heads=num_heads,
+            ffn_dim=d_model * 3 * 4,
+            num_layers=num_layers,
+            norm_first=True,
+            dropout=dropout,
+            out_dim=out_dim,
+        )
+
+        self.feats_in = VNInFeature(d_model, dim=4)
+        self.feats_eq = VNEqFeature(d_model, dim=4)
+
+    def forward(self, tokens, valid_masks):
+        """Forward pass.
+
+        Args:
+            tokens: [B, C, 3, N]
+            valid_masks: [B, N], True for valid, False for padded
+
+        Returns:
+            torch.Tensor: [B, C, 3, N]
+        """
+        # map tokens to rotation-invariant features
+        tokens_in = self.feats_in(tokens).flatten(1, 2)  # [B, C*3, N]
+        tokens_in = tokens_in.transpose(1, 2).contiguous()  # [B, N, C*3]
+        out_in = super().forward(tokens_in, valid_masks)  # [B, N, C*3]
+        # back to [B, C, 3, N]
+        out_in = out_in.transpose(1, 2).unflatten(1, (-1, 3)).contiguous()
+        # map the invariant features back to the rotation of the tokens
+        out_eq = self.feats_eq(tokens, out_in)
+        return out_eq
+
+
""" test code
import torch
from multi_part_assembly.models import CanonicalVNTransformerEncoder, VNTransformerEncoder
from multi_part_assembly.utils import random_rotation_matrixs
vn_trans = CanonicalVNTransformerEncoder(16, 4, 2, 0.)
pc = torch.rand(2, 16, 3, 100)
rmat = random_rotation_matrixs((2, 16)) # [2, 16, 3, 3]
rot_pc = rmat @ pc
trans_pc = vn_trans(pc)
rot_trans_pc = rmat @ trans_pc
trans_rot_pc = vn_trans(rot_pc)
(rot_trans_pc - trans_rot_pc).abs().max()
"""


def build_vn_transformer_encoder(
    d_model,
    num_heads,
@@ -115,7 +183,6 @@ def __init__(
        d_model,
        num_heads,
        num_layers,
-        relu=True,
        dropout=0.,
        out_dim=None,
    ):
@@ -125,7 +192,6 @@
            d_model=d_model,
            num_heads=num_heads,
            num_layers=num_layers,
-            relu=relu,
            dropout=dropout,
        )
        self.out_fc = VNLinear(d_model, out_dim, dim=4) if \
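Aside: a short usage sketch (not part of this commit) for CanonicalVNTransformerEncoder with an actual padding mask. Mask semantics (True = valid, False = padded) follow the forward docstring above; the padded layout here is only illustrative.

import torch
from multi_part_assembly.models import CanonicalVNTransformerEncoder

encoder = CanonicalVNTransformerEncoder(16, 4, 2, 0.)  # d_model, heads, layers, dropout
tokens = torch.rand(2, 16, 3, 20)  # [B, C, 3, N] part tokens
valid_masks = torch.ones(2, 20).bool()
valid_masks[0, -2:] = False  # pretend the last two parts of sample 0 are padding

out = encoder(tokens, valid_masks)
print(out.shape)  # torch.Size([2, 16, 3, 20]), rotation equivariant w.r.t. tokens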
5 changes: 2 additions & 3 deletions multi_part_assembly/models/pn_transformer/vn_network.py
@@ -3,7 +3,7 @@
from multi_part_assembly.models import build_encoder, VNPoseRegressor

from .network import PNTransformer
-from .transformer import VNTransformerEncoder
+from .transformer import CanonicalVNTransformerEncoder


class VNPNTransformer(PNTransformer):
@@ -41,11 +41,10 @@ def _init_encoder(self):

    def _init_corr_module(self):
        """Part feature interaction module."""
-        corr_module = VNTransformerEncoder(
+        corr_module = CanonicalVNTransformerEncoder(
            d_model=self.pc_feat_dim,
            num_heads=self.cfg.model.transformer_heads,
            num_layers=self.cfg.model.transformer_layers,
-            relu=self.cfg.model.get('transformer_relu', True),
            dropout=0.,
        )
        return corr_module
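Aside: with this change, VNPNTransformer no longer reads the transformer_relu config entry. Below is a hedged sketch (not from the commit) of the new corr_module wiring, with a stand-in namespace in place of the repo's real config object:

from types import SimpleNamespace

from multi_part_assembly.models import CanonicalVNTransformerEncoder

# stand-ins for self.cfg and self.pc_feat_dim; field names taken from the diff above
cfg = SimpleNamespace(
    model=SimpleNamespace(transformer_heads=4, transformer_layers=2))
pc_feat_dim = 16

corr_module = CanonicalVNTransformerEncoder(
    d_model=pc_feat_dim,
    num_heads=cfg.model.transformer_heads,
    num_layers=cfg.model.transformer_layers,
    dropout=0.,
)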
