diff --git a/EAMLP/~ b/EAMLP/~
deleted file mode 100644
index 4512d04..0000000
--- a/EAMLP/~
+++ /dev/null
@@ -1,187 +0,0 @@
-# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
-#
-# This source code is licensed under the Clear BSD License
-# LICENSE file in the root directory of this file
-# All rights reserved.
-"""
-Borrow from timm(https://github.com/rwightman/pytorch-image-models)
-"""
-import torch
-import torch.nn as nn
-import numpy as np
-from timm.models.layers import DropPath
-
-class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        self.num_heads = num_heads
-        self.dim = dim
-
-        self.mid_dim = self.dim // num_heads
-        self.k0 = 64
-        self.k1 = 64
-        self.k2 = 64
-        self.k3 = 64
-        self.k4 = 64
-        self.k5 = 64
-
-        self.scale = qk_scale or self.mid_dim ** -0.5
-
-        self.linear_10 = nn.Linear(self.mid_dim, self.k0, bias=False)
-        self.linear_11 = nn.Linear(self.mid_dim, self.k1, bias=False)
-        self.linear_12 = nn.Linear(self.mid_dim, self.k2, bias=False)
-        self.linear_13 = nn.Linear(self.mid_dim, self.k3, bias=False)
-        self.linear_14 = nn.Linear(self.mid_dim, self.k4, bias=False)
-        self.linear_15 = nn.Linear(self.mid_dim, self.k5, bias=False)
-
-        self.linear_20 = nn.Linear(self.k0, self.mid_dim, bias=False)
-        self.linear_21 = nn.Linear(self.k1, self.mid_dim, bias=False)
-        self.linear_22 = nn.Linear(self.k2, self.mid_dim, bias=False)
-        self.linear_23 = nn.Linear(self.k3, self.mid_dim, bias=False)
-        self.linear_24 = nn.Linear(self.k4, self.mid_dim, bias=False)
-        self.linear_25 = nn.Linear(self.k5, self.mid_dim, bias=False)
-
-        self.linear_20.weight.data = self.linear_10.weight.data.permute(1, 0)
-        self.linear_21.weight.data = self.linear_11.weight.data.permute(1, 0)
-        self.linear_22.weight.data = self.linear_12.weight.data.permute(1, 0)
-        self.linear_23.weight.data = self.linear_13.weight.data.permute(1, 0)
-        self.linear_24.weight.data = self.linear_14.weight.data.permute(1, 0)
-        self.linear_25.weight.data = self.linear_15.weight.data.permute(1, 0)
-
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(self.mid_dim * num_heads, self.dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        delta = C // self.num_heads
-        x0 = x[:,:,:delta]
-        x1 = x[:,:,delta: delta * 2]
-        x2 = x[:,:,delta * 2: delta * 3]
-        x3 = x[:,:,delta * 3: delta * 4]
-        x4 = x[:,:,delta * 4: delta * 5]
-        x5 = x[:,:,delta * 5:]
-
-        x0 = self.linear_10(x0) * self.scale
-        x0 = x0.softmax(dim=1)
-        x0 = x0 / (1e-9 + x0.sum(dim=-1, keepdim=True)) #
-        x0 = self.attn_drop(x0)
-        x0 = self.linear_20(x0)
-
-        x1 = self.linear_11(x1) * self.scale
-        x1 = x1.softmax(dim=1)
-        x1 = x1 / (1e-9 + x1.sum(dim=-1, keepdim=True)) #
-        x1 = self.attn_drop(x1)
-        x1 = self.linear_21(x1)
-
-        x2 = self.linear_12(x2) * self.scale
-        x2 = x2.softmax(dim=1)
-        x2 = x2 / (1e-9 + x2.sum(dim=-1, keepdim=True)) #
-        x2 = self.attn_drop(x2)
-        x2 = self.linear_22(x2)
-
-        x3 = self.linear_13(x3) * self.scale
-        x3 = x3.softmax(dim=1)
-        x3 = x3 / (1e-9 + x3.sum(dim=-1, keepdim=True)) #
-        x3 = self.attn_drop(x3)
-        x3 = self.linear_23(x3)
-
-        '''
-        x4 = self.linear_14(x4) * self.scale
-        x4 = x4.softmax(dim=1)
-        x4 = x4 / (1e-9 + x4.sum(dim=-1, keepdim=True)) #
-        x4 = self.attn_drop(x4)
-        x4 = self.linear_24(x4)
-
-        x5 = self.linear_15(x5) * self.scale
-        x5 = x5.softmax(dim=1)
-        x5 = x5 / (1e-9 + x5.sum(dim=-1, keepdim=True)) #
-        x5 = self.attn_drop(x5)
-        x5 = self.linear_25(x5)
-        '''
-        x = torch.cat([x0, x1, x2, x3], dim=-1)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        return x
-'''
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.scale = qk_scale or head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-'''
-
-class Block(nn.Module):
-
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.norm1 = norm_layer(dim)
-        self.attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-    def forward(self, x):
-        x = x + self.drop_path(self.attn(self.norm1(x)))
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
-        return x
-
-
-def get_sinusoid_encoding(n_position, d_hid):
-    ''' Sinusoid position encoding table '''
-
-    def get_position_angle_vec(position):
-        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
-
-    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
-    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
-    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
-
-    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
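
Note on the removed file: it defined a factorized, per-head linear attention variant (each head's slice is projected into a 64-dimensional bottleneck, softmax-normalized over tokens, renormalized by its row sum, then projected back), alongside the standard Mlp, Block, and sinusoidal position-encoding helpers, with the usual QKV self-attention kept only as a commented-out reference. The following is a minimal, hypothetical smoke test of those modules, not part of this diff: the module name eamlp_block, dim=256, num_heads=4, seq_len=197, and batch=2 are all assumptions. num_heads=4 is chosen deliberately because the forward pass concatenates only the four active 64-dim branches (x0..x3), so 4 * (dim // num_heads) must equal dim for self.proj to accept the result.

# Hypothetical usage sketch; assumes the deleted file's contents were saved as eamlp_block.py.
import torch

from eamlp_block import Block, get_sinusoid_encoding  # assumed module name

dim, num_heads, seq_len, batch = 256, 4, 197, 2

# Transformer block built around the factorized attention defined above.
block = Block(dim=dim, num_heads=num_heads, mlp_ratio=4.0, drop_path=0.1)

# Fixed sinusoidal position table of shape (1, seq_len, dim); broadcasts over the batch.
pos = get_sinusoid_encoding(n_position=seq_len, d_hid=dim)

x = torch.randn(batch, seq_len, dim) + pos
with torch.no_grad():
    y = block(x)

print(y.shape)  # expected: torch.Size([2, 197, 256])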