forked from sczhou/CodeFormer
Commit
Showing 83 changed files with 8,159 additions and 8 deletions.
@@ -0,0 +1 @@
1.3.2
@@ -0,0 +1,11 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
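As a quick sanity check of the package entry point above: the final import re-exports the generated version metadata, so (assuming basicsr is installed and importable) the following one-liner works:

import basicsr

# __version__ is re-exported from the generated version module by the last import above.
print(basicsr.__version__)  # presumably '1.3.2', matching the version string added in this commit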
@@ -0,0 +1,25 @@
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
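For illustration, a minimal sketch of how build_network consumes an option dict. The keys mirror the ResNetArcFace constructor added later in this commit; the values ([2, 2, 2, 2], a ResNet-18-style layout) are an assumption, not taken from a config in this diff:

from basicsr.archs import build_network

# build_network deep-copies the dict, pops 'type' to look up the class in
# ARCH_REGISTRY, and forwards the remaining keys as constructor kwargs.
opt = {
    'type': 'ResNetArcFace',  # registered via @ARCH_REGISTRY.register() below
    'block': 'IRBlock',
    'layers': [2, 2, 2, 2],   # assumed block counts per stage
    'use_se': True,
}
net = build_network(opt)

Note the auto-scan at module import time: every *_arch.py file in this folder is imported, which is what triggers the @ARCH_REGISTRY.register() decorators before build_network ever runs.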
@@ -0,0 +1,245 @@
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY


def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.
    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.
    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.
    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.
    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.
    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
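To make the tensor shapes concrete, here is a small smoke test of the classes above, not part of the commit. The 128x128 grayscale input is an assumption chosen to match fc5: 128 halves to 64 after the max pool, then three stride-2 stages give the 8x8 map that 512 * 8 * 8 expects:

import torch

# SEBlock performs channel-wise reweighting; it never changes the feature shape.
se = SEBlock(channel=64)
feat = torch.randn(2, 64, 16, 16)
assert se(feat).shape == feat.shape

# Only 'IRBlock' is mapped from a string in __init__, and _make_layer always
# forwards use_se, so IRBlock is the block this class is wired for in practice.
model = ResNetArcFace(block='IRBlock', layers=[2, 2, 2, 2], use_se=True)
model.eval()  # BatchNorm1d after fc5 would reject a batch of 1 in train mode

with torch.no_grad():
    out = model(torch.randn(1, 1, 128, 128))  # conv1 expects a single channel
print(out.shape)  # torch.Size([1, 512]) -- the 512-D identity embedding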