-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
89 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
# @Author: Haozhe Xie | ||
# @Date: 2023-04-12 19:53:21 | ||
# @Last Modified by: Haozhe Xie | ||
# @Last Modified at: 2023-06-15 15:11:46 | ||
# @Last Modified at: 2023-06-19 19:46:23 | ||
# @Email: [email protected] | ||
# @Ref: https://github.com/FrozenBurning/SceneDreamer | ||
|
||
|
@@ -489,48 +489,102 @@ class LocalEncoder(torch.nn.Module): | |
def __init__(self, cfg): | ||
super(LocalEncoder, self).__init__() | ||
n_classes = cfg.DATASETS.OSM_LAYOUT.N_CLASSES | ||
self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3) | ||
self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1) | ||
self.seg_conv = torch.nn.Conv2d( | ||
n_classes, 32, kernel_size=7, stride=2, padding=3 | ||
n_classes, 32, kernel_size=4, stride=2, padding=1 | ||
) | ||
if cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "BATCH_NORM": | ||
self.bn1 = torch.nn.BatchNorm2d(64) | ||
elif cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "GROUP_NORM": | ||
self.bn1 = torch.nn.GroupNorm(32, 64) | ||
else: | ||
raise ValueError( | ||
"Unknown normalization: %s" % cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM | ||
) | ||
self.conv2 = ResConvBlock(64, 128, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) | ||
self.conv3 = ResConvBlock(128, 256, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) | ||
self.conv4 = ResConvBlock(256, 512, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) | ||
self.dconv5 = torch.nn.ConvTranspose2d( | ||
512, 128, kernel_size=4, stride=2, padding=1 | ||
self.conv1 = torch.nn.Sequential( | ||
torch.nn.LeakyReLU(0.2, inplace=True), | ||
torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(128), | ||
) | ||
self.conv2 = torch.nn.Sequential( | ||
torch.nn.LeakyReLU(0.2, inplace=True), | ||
torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(256), | ||
) | ||
self.conv3 = torch.nn.Sequential( | ||
torch.nn.LeakyReLU(0.2, inplace=True), | ||
torch.nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(512), | ||
) | ||
self.conv4 = torch.nn.Sequential( | ||
torch.nn.LeakyReLU(0.2, inplace=True), | ||
torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(512), | ||
) | ||
self.conv5 = torch.nn.Sequential( | ||
torch.nn.LeakyReLU(0.2, inplace=True), | ||
torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), | ||
) | ||
self.dconv5 = torch.nn.Sequential( | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(512), | ||
torch.nn.Dropout2d(p=0.5, inplace=True), | ||
) | ||
self.dconv4 = torch.nn.Sequential( | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(512), | ||
torch.nn.Dropout2d(p=0.5, inplace=True), | ||
) | ||
self.dconv6 = torch.nn.ConvTranspose2d( | ||
128, 32, kernel_size=4, stride=2, padding=1 | ||
self.dconv3 = torch.nn.Sequential( | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(256), | ||
) | ||
self.dconv7 = torch.nn.Conv2d( | ||
32, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1, kernel_size=1 | ||
self.dconv2 = torch.nn.Sequential( | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(128), | ||
) | ||
self.dconv1 = torch.nn.Sequential( | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1), | ||
torch.nn.BatchNorm2d(64), | ||
) | ||
self.dconv0 = torch.nn.ConvTranspose2d( | ||
128, | ||
cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1, | ||
kernel_size=4, | ||
stride=2, | ||
padding=1, | ||
) | ||
|
||
def forward(self, hf_seg): | ||
def forward(self, hf_seg, z=None): | ||
hf = self.hf_conv(hf_seg[:, [0]]) | ||
seg = self.seg_conv(hf_seg[:, 1:]) | ||
out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True) | ||
# print(out.size()) # torch.Size([N, 64, H/2, W/2]) | ||
out = F.avg_pool2d(self.conv2(out), 2, stride=2) | ||
# print(out.size()) # torch.Size([N, 128, H/4, W/4]) | ||
out = self.conv3(out) | ||
# print(out.size()) # torch.Size([N, 256, H/4, W/4]) | ||
out = self.conv4(out) | ||
# print(out.size()) # torch.Size([N, 512, H/4, W/4]) | ||
out = self.dconv5(out) | ||
# print(out.size()) # torch.Size([N, 128, H/2, W/2]) | ||
out = self.dconv6(out) | ||
# print(out.size()) # torch.Size([N, 32, H, W]) | ||
out = self.dconv7(out) | ||
# print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W]) | ||
out0 = torch.cat([hf, seg], dim=1) | ||
# print(out0.size()) # torch.Size([N, 64, H/2, W/2]) | ||
out1 = self.conv1(out0) | ||
# print(out1.size()) # torch.Size([N, 128, H/4, W/4]) | ||
out2 = self.conv2(out1) | ||
# print(out2.size()) # torch.Size([N, 256, H/8, W/8]) | ||
out3 = self.conv3(out2) | ||
# print(out3.size()) # torch.Size([N, 512, H/16, W/16]) | ||
out4 = self.conv4(out3) | ||
# print(out4.size()) # torch.Size([N, 512, H/32, W/32]) | ||
out5 = self.conv5(out4) | ||
# print(out5.size()) # torch.Size([N, 512, H/64, W/64]) | ||
out = self.dconv5(out5) | ||
# print(out.size()) # torch.Size([N, 512, H/32, W/32]) | ||
out = torch.cat([out, out4], dim=1) | ||
# print(out.size()) # torch.Size([N, 1024, H/32, W/32]) | ||
out = self.dconv4(out) | ||
out = torch.cat([out, out3], dim=1) | ||
# print(out.size()) # torch.Size([1, 1024, H/16, W/16]) | ||
out = self.dconv3(out) | ||
out = torch.cat([out, out2], dim=1) | ||
# print(out.size()) # torch.Size([1, 512, H/8, W/8]) | ||
out = self.dconv2(out) | ||
out = torch.cat([out, out1], dim=1) | ||
# print(out.size()) # torch.Size([1, 256, H/4, W/4]) | ||
out = self.dconv1(out) | ||
out = torch.cat([out, out0], dim=1) | ||
# print(out.size()) # torch.Size([1, 128, H/2, W/2]) | ||
out = self.dconv0(out) | ||
# print(out.size()) # torch.Size([1, OUT_DIM - 1, H, W]) | ||
return torch.tanh(out) | ||
|
||
|
||
|
@@ -813,82 +867,6 @@ def forward(self, x): | |
return self.layers(x) | ||
|
||
|
||
class ResConvBlock(torch.nn.Module): | ||
def __init__(self, in_channels, out_channels, norm, bias=False): | ||
super(ResConvBlock, self).__init__() | ||
# conv3x3(in_planes, int(out_planes / 2)) | ||
self.conv1 = torch.nn.Conv2d( | ||
in_channels, | ||
out_channels // 2, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
bias=bias, | ||
) | ||
# conv3x3(int(out_planes / 2), int(out_planes / 4)) | ||
self.conv2 = torch.nn.Conv2d( | ||
out_channels // 2, | ||
out_channels // 4, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
bias=bias, | ||
) | ||
# conv3x3(int(out_planes / 4), int(out_planes / 4)) | ||
self.conv3 = torch.nn.Conv2d( | ||
out_channels // 4, | ||
out_channels // 4, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
bias=bias, | ||
) | ||
if norm == "BATCH_NORM": | ||
self.bn1 = torch.nn.BatchNorm2d(in_channels) | ||
self.bn2 = torch.nn.BatchNorm2d(out_channels // 2) | ||
self.bn3 = torch.nn.BatchNorm2d(out_channels // 4) | ||
self.bn4 = torch.nn.BatchNorm2d(in_channels) | ||
elif norm == "GROUP_NORM": | ||
self.bn1 = torch.nn.GroupNorm(32, in_channels) | ||
self.bn2 = torch.nn.GroupNorm(32, out_channels // 2) | ||
self.bn3 = torch.nn.GroupNorm(32, out_channels // 4) | ||
self.bn4 = torch.nn.GroupNorm(32, in_channels) | ||
|
||
if in_channels != out_channels: | ||
self.downsample = torch.nn.Sequential( | ||
self.bn4, | ||
torch.nn.ReLU(True), | ||
torch.nn.Conv2d( | ||
in_channels, out_channels, kernel_size=1, stride=1, bias=False | ||
), | ||
) | ||
else: | ||
self.downsample = None | ||
|
||
def forward(self, x): | ||
residual = x | ||
# print(residual.size()) # torch.Size([N, 64, H, W]) | ||
out1 = self.bn1(x) | ||
out1 = F.relu(out1, True) | ||
out1 = self.conv1(out1) | ||
# print(out1.size()) # torch.Size([N, 64, H, W]) | ||
out2 = self.bn2(out1) | ||
out2 = F.relu(out2, True) | ||
out2 = self.conv2(out2) | ||
# print(out2.size()) # torch.Size([N, 32, H, W]) | ||
out3 = self.bn3(out2) | ||
out3 = F.relu(out3, True) | ||
out3 = self.conv3(out3) | ||
# print(out3.size()) # torch.Size([N, 32, H, W]) | ||
out3 = torch.cat((out1, out2, out3), dim=1) | ||
# print(out3.size()) # torch.Size([N, 128, H, W]) | ||
if self.downsample is not None: | ||
residual = self.downsample(residual) | ||
# print(residual.size()) # torch.Size([N, 128, H, W]) | ||
out3 += residual | ||
return out3 | ||
|
||
|
||
class ModLinear(torch.nn.Module): | ||
r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod). | ||
Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across | ||
|