Commit

Use Pix2Pix-based LocalEncoder.
hzxie committed Jun 19, 2023
1 parent 06806b3 commit af441c0
Showing 1 changed file with 89 additions and 111 deletions.
200 changes: 89 additions & 111 deletions models/gancraft.py
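In short, this commit swaps the ResConvBlock-based LocalEncoder for a Pix2Pix-style (U-Net) encoder-decoder: six stride-2, 4x4 convolutions with LeakyReLU and BatchNorm on the way down, mirrored 4x4 transposed convolutions with ReLU and BatchNorm (plus dropout at the two innermost stages) on the way up, skip connections concatenated at every resolution, and a tanh on the output. Below is a minimal two-level sketch of that pattern, for illustration only; the class name and channel sizes are placeholders rather than part of the commit.

# Illustration only: a two-level reduction of the down/up + skip-concat
# pattern adopted by the new LocalEncoder. Not part of models/gancraft.py.
import torch

class MiniPix2Pix(torch.nn.Module):
    def __init__(self, in_ch=10, out_ch=63):
        super(MiniPix2Pix, self).__init__()
        # Downsampling path: stride-2 4x4 convs with LeakyReLU + BatchNorm.
        self.down1 = torch.nn.Conv2d(in_ch, 64, kernel_size=4, stride=2, padding=1)
        self.down2 = torch.nn.Sequential(
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
        )
        # Upsampling path: stride-2 4x4 transposed convs with ReLU + BatchNorm.
        self.up2 = torch.nn.Sequential(
            torch.nn.ReLU(inplace=True),
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
        )
        # The skip concat doubles the channels: 64 (up2) + 64 (down1) = 128.
        self.up1 = torch.nn.ConvTranspose2d(
            128, out_ch, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x):
        d1 = self.down1(x)                         # [N, 64, H/2, W/2]
        d2 = self.down2(d1)                        # [N, 128, H/4, W/4]
        u2 = self.up2(d2)                          # [N, 64, H/2, W/2]
        u1 = self.up1(torch.cat([u2, d1], dim=1))  # [N, out_ch, H, W]
        return torch.tanh(u1)

x = torch.rand(1, 10, 64, 64)
assert MiniPix2Pix()(x).size() == (1, 63, 64, 64)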
@@ -4,7 +4,7 @@
 # @Author: Haozhe Xie
 # @Date: 2023-04-12 19:53:21
 # @Last Modified by: Haozhe Xie
-# @Last Modified at: 2023-06-15 15:11:46
+# @Last Modified at: 2023-06-19 19:46:23
 # @Email: [email protected]
 # @Ref: https://github.com/FrozenBurning/SceneDreamer

@@ -489,48 +489,102 @@ class LocalEncoder(torch.nn.Module):
     def __init__(self, cfg):
         super(LocalEncoder, self).__init__()
         n_classes = cfg.DATASETS.OSM_LAYOUT.N_CLASSES
-        self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3)
+        self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
         self.seg_conv = torch.nn.Conv2d(
-            n_classes, 32, kernel_size=7, stride=2, padding=3
+            n_classes, 32, kernel_size=4, stride=2, padding=1
         )
-        if cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "BATCH_NORM":
-            self.bn1 = torch.nn.BatchNorm2d(64)
-        elif cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "GROUP_NORM":
-            self.bn1 = torch.nn.GroupNorm(32, 64)
-        else:
-            raise ValueError(
-                "Unknown normalization: %s" % cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM
-            )
-        self.conv2 = ResConvBlock(64, 128, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
-        self.conv3 = ResConvBlock(128, 256, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
-        self.conv4 = ResConvBlock(256, 512, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
-        self.dconv5 = torch.nn.ConvTranspose2d(
-            512, 128, kernel_size=4, stride=2, padding=1
+        self.conv1 = torch.nn.Sequential(
+            torch.nn.LeakyReLU(0.2, inplace=True),
+            torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(128),
         )
+        self.conv2 = torch.nn.Sequential(
+            torch.nn.LeakyReLU(0.2, inplace=True),
+            torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(256),
+        )
+        self.conv3 = torch.nn.Sequential(
+            torch.nn.LeakyReLU(0.2, inplace=True),
+            torch.nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(512),
+        )
+        self.conv4 = torch.nn.Sequential(
+            torch.nn.LeakyReLU(0.2, inplace=True),
+            torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(512),
+        )
+        self.conv5 = torch.nn.Sequential(
+            torch.nn.LeakyReLU(0.2, inplace=True),
+            torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
+        )
+        self.dconv5 = torch.nn.Sequential(
+            torch.nn.ReLU(inplace=True),
+            torch.nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(512),
+            torch.nn.Dropout2d(p=0.5, inplace=True),
+        )
+        self.dconv4 = torch.nn.Sequential(
+            torch.nn.ReLU(inplace=True),
+            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(512),
+            torch.nn.Dropout2d(p=0.5, inplace=True),
+        )
-        self.dconv6 = torch.nn.ConvTranspose2d(
-            128, 32, kernel_size=4, stride=2, padding=1
+        self.dconv3 = torch.nn.Sequential(
+            torch.nn.ReLU(inplace=True),
+            torch.nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(256),
         )
-        self.dconv7 = torch.nn.Conv2d(
-            32, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1, kernel_size=1
+        self.dconv2 = torch.nn.Sequential(
+            torch.nn.ReLU(inplace=True),
+            torch.nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(128),
         )
+        self.dconv1 = torch.nn.Sequential(
+            torch.nn.ReLU(inplace=True),
+            torch.nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1),
+            torch.nn.BatchNorm2d(64),
+        )
+        self.dconv0 = torch.nn.ConvTranspose2d(
+            128,
+            cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+        )
 
-    def forward(self, hf_seg):
+    def forward(self, hf_seg, z=None):
         hf = self.hf_conv(hf_seg[:, [0]])
         seg = self.seg_conv(hf_seg[:, 1:])
-        out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True)
-        # print(out.size()) # torch.Size([N, 64, H/2, W/2])
-        out = F.avg_pool2d(self.conv2(out), 2, stride=2)
-        # print(out.size()) # torch.Size([N, 128, H/4, W/4])
-        out = self.conv3(out)
-        # print(out.size()) # torch.Size([N, 256, H/4, W/4])
-        out = self.conv4(out)
-        # print(out.size()) # torch.Size([N, 512, H/4, W/4])
-        out = self.dconv5(out)
-        # print(out.size()) # torch.Size([N, 128, H/2, W/2])
-        out = self.dconv6(out)
-        # print(out.size()) # torch.Size([N, 32, H, W])
-        out = self.dconv7(out)
-        # print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W])
+        out0 = torch.cat([hf, seg], dim=1)
+        # print(out0.size()) # torch.Size([N, 64, H/2, W/2])
+        out1 = self.conv1(out0)
+        # print(out1.size()) # torch.Size([N, 128, H/4, W/4])
+        out2 = self.conv2(out1)
+        # print(out2.size()) # torch.Size([N, 256, H/8, W/8])
+        out3 = self.conv3(out2)
+        # print(out3.size()) # torch.Size([N, 512, H/16, W/16])
+        out4 = self.conv4(out3)
+        # print(out4.size()) # torch.Size([N, 512, H/32, W/32])
+        out5 = self.conv5(out4)
+        # print(out5.size()) # torch.Size([N, 512, H/64, W/64])
+        out = self.dconv5(out5)
+        # print(out.size()) # torch.Size([N, 512, H/32, W/32])
+        out = torch.cat([out, out4], dim=1)
+        # print(out.size()) # torch.Size([N, 1024, H/32, W/32])
+        out = self.dconv4(out)
+        out = torch.cat([out, out3], dim=1)
+        # print(out.size()) # torch.Size([N, 1024, H/16, W/16])
+        out = self.dconv3(out)
+        out = torch.cat([out, out2], dim=1)
+        # print(out.size()) # torch.Size([N, 512, H/8, W/8])
+        out = self.dconv2(out)
+        out = torch.cat([out, out1], dim=1)
+        # print(out.size()) # torch.Size([N, 256, H/4, W/4])
+        out = self.dconv1(out)
+        out = torch.cat([out, out0], dim=1)
+        # print(out.size()) # torch.Size([N, 128, H/2, W/2])
+        out = self.dconv0(out)
+        # print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W])
         return torch.tanh(out)
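
A quick smoke test for the new forward pass (illustrative, not part of the commit): the SimpleNamespace below is an assumed stand-in for the project's real config object, and the values 9 (N_CLASSES) and 64 (ENCODER_OUT_DIM) are arbitrary. Because the encoder halves the resolution six times before decoding back up, H and W should be multiples of 64.

# Assumed config stand-in; the real cfg comes from the project's config system.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    DATASETS=SimpleNamespace(OSM_LAYOUT=SimpleNamespace(N_CLASSES=9)),
    NETWORK=SimpleNamespace(GANCRAFT=SimpleNamespace(ENCODER_OUT_DIM=64)),
)
encoder = LocalEncoder(cfg)
# Channel 0 is the height field; the remaining 9 are the one-hot layout.
hf_seg = torch.rand(1, 1 + 9, 128, 128)
out = encoder(hf_seg)
assert out.size() == (1, 63, 128, 128)  # ENCODER_OUT_DIM - 1 channels, in (-1, 1)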


@@ -813,82 +867,6 @@ def forward(self, x):
         return self.layers(x)
 
 
-class ResConvBlock(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, norm, bias=False):
-        super(ResConvBlock, self).__init__()
-        # conv3x3(in_planes, int(out_planes / 2))
-        self.conv1 = torch.nn.Conv2d(
-            in_channels,
-            out_channels // 2,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=bias,
-        )
-        # conv3x3(int(out_planes / 2), int(out_planes / 4))
-        self.conv2 = torch.nn.Conv2d(
-            out_channels // 2,
-            out_channels // 4,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=bias,
-        )
-        # conv3x3(int(out_planes / 4), int(out_planes / 4))
-        self.conv3 = torch.nn.Conv2d(
-            out_channels // 4,
-            out_channels // 4,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=bias,
-        )
-        if norm == "BATCH_NORM":
-            self.bn1 = torch.nn.BatchNorm2d(in_channels)
-            self.bn2 = torch.nn.BatchNorm2d(out_channels // 2)
-            self.bn3 = torch.nn.BatchNorm2d(out_channels // 4)
-            self.bn4 = torch.nn.BatchNorm2d(in_channels)
-        elif norm == "GROUP_NORM":
-            self.bn1 = torch.nn.GroupNorm(32, in_channels)
-            self.bn2 = torch.nn.GroupNorm(32, out_channels // 2)
-            self.bn3 = torch.nn.GroupNorm(32, out_channels // 4)
-            self.bn4 = torch.nn.GroupNorm(32, in_channels)
-
-        if in_channels != out_channels:
-            self.downsample = torch.nn.Sequential(
-                self.bn4,
-                torch.nn.ReLU(True),
-                torch.nn.Conv2d(
-                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
-                ),
-            )
-        else:
-            self.downsample = None
-
-    def forward(self, x):
-        residual = x
-        # print(residual.size()) # torch.Size([N, 64, H, W])
-        out1 = self.bn1(x)
-        out1 = F.relu(out1, True)
-        out1 = self.conv1(out1)
-        # print(out1.size()) # torch.Size([N, 64, H, W])
-        out2 = self.bn2(out1)
-        out2 = F.relu(out2, True)
-        out2 = self.conv2(out2)
-        # print(out2.size()) # torch.Size([N, 32, H, W])
-        out3 = self.bn3(out2)
-        out3 = F.relu(out3, True)
-        out3 = self.conv3(out3)
-        # print(out3.size()) # torch.Size([N, 32, H, W])
-        out3 = torch.cat((out1, out2, out3), dim=1)
-        # print(out3.size()) # torch.Size([N, 128, H, W])
-        if self.downsample is not None:
-            residual = self.downsample(residual)
-        # print(residual.size()) # torch.Size([N, 128, H, W])
-        out3 += residual
-        return out3
-
-
 class ModLinear(torch.nn.Module):
     r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
     Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
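For reference, the removed ResConvBlock built its output by concatenating three successive conv outputs (out_channels // 2 + out_channels // 4 + out_channels // 4 = out_channels) and adding a residual that is projected by a 1x1 convolution whenever the channel counts differ. A quick shape check against the class as it existed before this commit, for illustration only:

# Uses the pre-commit ResConvBlock definition shown in the diff above.
import torch

block = ResConvBlock(64, 128, norm="GROUP_NORM")
x = torch.rand(2, 64, 32, 32)
y = block(x)  # concat of 64 + 32 + 32 channels, plus the projected residual
assert y.size() == (2, 128, 32, 32)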
