
Commit

fix bugs
Wuziyi616 committed Jun 29, 2022
1 parent 882efe5 commit b1f7d78
Showing 5 changed files with 64 additions and 32 deletions.
multi_part_assembly/models/modules/vnn/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
 from .modules import VNLinear, VNBatchNorm, VNLayerNorm, VNReLU, VNLeakyReLU, VNLinearBNLeakyReLU, VNMaxPool, VNInFeature
-from .transformer import VNTransformerEncoderLayer
+from .transformer import VNTransformerEncoderLayer, VNSelfAttention
multi_part_assembly/models/modules/vnn/modules.py (24 changes: 11 additions & 13 deletions)
@@ -165,7 +165,7 @@ def forward(self, x):
         Returns:
             features of the same shape after LN in each instance
         """
-        B, C = x.shape[0], x.shape[-2]
+        B, C = x.shape[:2]
         ori_shape = x.shape
         x = x.transpose(-1, 1).reshape(B, -1, 3, C)
         x = self.ln(x)
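
Note: this module takes [B, C, 3, N] inputs (see the test snippet below), so x.shape[-2] is the vector dimension 3 rather than the channel count; the old line set C = 3 and corrupted the reshape that follows. A minimal sketch of the difference, with the input shape borrowed from the test code:

import torch

x = torch.rand(2, 16, 3, 100)            # [B, C, 3, N]
B_old, C_old = x.shape[0], x.shape[-2]   # C_old == 3, the vector dim (wrong)
B, C = x.shape[:2]                       # C == 16, the channel dim (correct)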
@@ -237,16 +237,14 @@ def forward(self, x):


""" test code
from multi_part_assembly.utils import random_quaternions, qrot
pc = torch.randn(2, 1000, 3)
quat = random_quaternions(2)
rot_pc = qrot(quat, pc)
pc = pc.reshape(2, 10, 100, 3).permute(0, 1, 3, 2) # [B, C, 3, N, ...]
rot_pc = rot_pc.reshape(2, 10, 100, 3).permute(0, 1, 3, 2) # same
model = VNLinear(10, 20)
out = model(pc).transpose(-1, -2).reshape(2, 2000, 3)
rot_out = model(rot_pc).transpose(-1, -2).reshape(2, 2000, 3)
(qrot(quat, out) - rot_out).abs().max()
import torch
from multi_part_assembly.models import VNLayerNorm
from multi_part_assembly.utils import random_rotation_matrixs
vn_ln = VNLayerNorm(16)
pc = torch.rand(2, 16, 3, 100)
rmat = random_rotation_matrixs(2)
rot_pc = rmat[:, None] @ pc
ln_pc = vn_ln(pc)
rot_ln_pc = rmat[:, None] @ ln_pc
ln_rot_pc = vn_ln(rot_pc)
"""
multi_part_assembly/models/modules/vnn/transformer.py (32 changes: 26 additions & 6 deletions)
@@ -10,11 +10,16 @@


 class VNSelfAttention(nn.Module):
-    """Inspired by VNT-Net: https://arxiv.org/pdf/2205.09690.pdf."""
+    """Inspired by VNT-Net: https://arxiv.org/pdf/2205.09690.pdf.
 
-    def __init__(self, d_model, n_head, dropout=0.1):
+    Note that we cannot use dropout in VN networks.
+    """
+
+    def __init__(self, d_model, n_head, dropout=0.):
         super().__init__()
 
         assert d_model % n_head == 0
+        assert dropout == 0.
         self.n_head = n_head
 
         # key, query, value projections for all heads
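
Note: the dropout ban is what keeps the layer rotation-equivariant. Standard dropout masks scalar entries independently, so it can zero a single x/y/z component of a vector feature, and masking no longer commutes with rotation. A small sketch (the shapes and the 90-degree rotation are illustrative assumptions):

import torch

drop = torch.nn.Dropout(p=0.5)        # acts on scalars, not on 3D vectors
v = torch.rand(16, 3)                 # 16 vector features
R = torch.tensor([[0., -1., 0.],      # 90-degree rotation about z
                  [1., 0., 0.],
                  [0., 0., 1.]])
torch.manual_seed(0)
a = drop(v @ R.T)                     # rotate, then dropout
torch.manual_seed(0)                  # reuse the exact same mask
b = drop(v) @ R.T                     # dropout, then rotate
print((a - b).abs().max())            # > 0: the two orders disagree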
@@ -43,11 +48,11 @@ def forward(self, x, src_key_padding_mask=None):

         # [B, nh, N, hs*3]
         k = self.key(x).reshape(B, self.n_head, C // self.n_head, 3, N).\
-            permute(0, 1, 4, 2, 3).flatten(-1, -2)
+            permute(0, 1, 4, 2, 3).flatten(-2, -1)
         q = self.query(x).reshape(B, self.n_head, C // self.n_head, 3, N).\
-            permute(0, 1, 4, 2, 3).flatten(-1, -2)
+            permute(0, 1, 4, 2, 3).flatten(-2, -1)
         v = self.value(x).reshape(B, self.n_head, C // self.n_head, 3, N).\
-            permute(0, 1, 4, 2, 3).flatten(-1, -2)
+            permute(0, 1, 4, 2, 3).flatten(-2, -1)
 
         # [B, nh, N, hs*3] x [B, nh, N, hs*3] --> [B, nh, N, N]
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
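
Note: Tensor.flatten(start_dim, end_dim) requires start_dim to come before end_dim, so the old flatten(-1, -2) raises a RuntimeError on the first forward pass; flatten(-2, -1) merges the head-size and coordinate dims into the hs*3 layout the comment describes. A quick check with assumed shapes:

import torch

t = torch.rand(2, 4, 100, 8, 3)       # [B, nh, N, hs, 3]
print(t.flatten(-2, -1).shape)        # torch.Size([2, 4, 100, 24]), i.e. hs*3
try:
    t.flatten(-1, -2)                 # start_dim after end_dim
except RuntimeError as e:
    print(e)                          # flatten() has invalid args: ...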
@@ -69,9 +74,10 @@ def forward(self, x, src_key_padding_mask=None):
 class VNTransformerEncoderLayer(nn.Module):
     """VN Transformer block."""
 
-    def __init__(self, d_model, n_head, relu=True, dropout=0.1):
+    def __init__(self, d_model, n_head, relu=True, dropout=0.):
         super().__init__()
 
+        assert dropout == 0.
         self.ln1 = VNLayerNorm(d_model)
         self.ln2 = VNLayerNorm(d_model)
         self.attn = VNSelfAttention(
@@ -100,3 +106,17 @@ def forward(self, x, src_key_padding_mask=None, src_mask=None):
         x = x + self.attn(self.ln1(x), src_key_padding_mask)
         x = x + self.mlp(self.ln2(x))
         return x
+
+
+""" test code
+import torch
+from multi_part_assembly.models import VNSelfAttention
+from multi_part_assembly.utils import random_rotation_matrixs
+vn_attn = VNSelfAttention(16, 1, 0)
+pc = torch.rand(2, 16, 3, 100)
+rmat = random_rotation_matrixs(2)
+rot_pc = rmat[:, None] @ pc
+attn_pc = vn_attn(pc)
+rot_attn_pc = rmat[:, None] @ attn_pc
+attn_rot_pc = vn_attn(rot_pc)
+"""
multi_part_assembly/utils/rotation.py (3 changes: 2 additions & 1 deletion)
@@ -207,7 +207,8 @@ def rot_type(self):

     @rot_type.setter
     def rot_type(self, rot_type):
-        raise NotImplementedError('cannot change rotation type')
+        raise NotImplementedError(
+            'please use convert() for rotation type conversion')
 
     @property
     def shape(self):
multi_part_assembly/utils/transforms.py (35 changes: 24 additions & 11 deletions)
@@ -16,9 +16,12 @@

 import torch
 
-from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_matrix
-from pytorch3d.transforms import quaternion_invert, quaternion_apply, quaternion_raw_multiply
+from pytorch3d.transforms import quaternion_invert, quaternion_apply, \
+    quaternion_raw_multiply
 from pytorch3d.transforms import random_quaternions as _random_quaternions
+from pytorch3d.transforms import matrix_to_quaternion, matrix_to_axis_angle, \
+    quaternion_to_matrix, quaternion_to_axis_angle, \
+    axis_angle_to_quaternion, axis_angle_to_matrix
 
 from .rotation import Rotation3D

@@ -103,7 +106,6 @@ def qtransform(t, q, v):

     qv = qrot(q, v)
     tqv = qv + t
-
     return tqv


@@ -112,18 +114,34 @@ def qtransform_invert(t, q, tqv):
     assert t.shape[-1] == 3
     if len(t.shape) == len(tqv.shape) - 1:
         t = t.unsqueeze(-2).repeat_interleave(tqv.shape[-2], dim=-2)
-
     assert t.shape == tqv.shape
-    qv = tqv - t
 
+    qv = tqv - t
     q_inv = quaternion_invert(q)
     v = qrot(q_inv, qv)
-
     return v
 
 
+# rmat-based transformations
+
+
+def random_rotation_matrixs(shape):
+    """Generate random rotation matrices by converting random quaternions
+    with quaternion_to_matrix.
+
+    Args:
+        shape: [N1, N2, ...]
+
+    Returns:
+        Rotation matrices as a tensor of shape (N1, N2, ..., 3, 3).
+    """
+    quat = random_quaternions(shape)
+    return quaternion_to_matrix(quat)
+
+
 def rmatq(r):
     """
     Convert rotation matrix(s) r to quaternion(s).
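
Note: the new random_rotation_matrixs helper just chains the repo's random_quaternions with pytorch3d's quaternion_to_matrix. A quick sanity check of shape and orthogonality, assuming the helper accepts an [N1, N2, ...] shape as its docstring states (tolerances are our assumptions):

import torch
from multi_part_assembly.utils import random_rotation_matrixs

rmat = random_rotation_matrixs([2, 10])       # -> [2, 10, 3, 3]
rrt = rmat @ rmat.transpose(-2, -1)           # R @ R^T should be identity
assert (rrt - torch.eye(3)).abs().max() < 1e-5
assert torch.allclose(torch.det(rmat), torch.ones(2, 10), atol=1e-5)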
@@ -150,11 +168,7 @@ def rmat_rot(r, v):

     assert r.shape[:-2] == v.shape[:-1]
 
-    original_shape = list(v.shape)
-    r = r.view(-1, 3, 3)
-    v = v.view(-1, 3, 1)
-
-    rv = torch.bmm(r, v).view(original_shape)
+    rv = (r @ v.unsqueeze(-1)).squeeze(-1)
     return rv
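
Note: the broadcasted matmul replaces the reshape-and-bmm dance and also tolerates non-contiguous inputs that .view() would reject. A quick equivalence check on arbitrary leading dims (shapes assumed):

import torch

r = torch.rand(2, 5, 3, 3)                    # batch of 3x3 matrices
v = torch.rand(2, 5, 3)                       # one 3D vector per matrix
rv_new = (r @ v.unsqueeze(-1)).squeeze(-1)    # broadcasted matmul
rv_old = torch.bmm(r.view(-1, 3, 3), v.view(-1, 3, 1)).view(2, 5, 3)
assert torch.allclose(rv_new, rv_old, atol=1e-6)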


@@ -176,7 +190,6 @@ def rmat_transform(t, r, v):

     rv = rmat_rot(r, v)
     trv = rv + t
-
     return trv


