-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathConditional_ResBlock.py
27 lines (24 loc) · 1.11 KB
/
Conditional_ResBlock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from torch import nn
# from ops.Conditional_BN import Conditional_BN
from components.Adain import Adain
class Conditional_ResBlock(nn.Module):
    """Residual block built from two (pad -> conv -> Adain) stages plus a skip.

    Both convolutions keep the channel count at ``in_channel``. Replication
    padding is sized so that, at stride 1, each convolution preserves the
    spatial size. This lets the block input be added back to the main branch.

    Args:
        in_channel: Number of input (and output) channels.
        k_size: Convolution kernel size (int). Odd and even sizes are both
            supported. Even sizes are padded asymmetrically, with the extra
            row/column on the right/bottom.
        n_class: Passed to each ``Adain`` layer as its second constructor
            argument. Presumably the number of condition classes; confirm
            against ``components.Adain``.
        stride: Stride of the first convolution only (int or 2-sequence).
            When greater than 1, the skip connection is subsampled with the
            same stride so both branches have matching spatial size.
    """

    def __init__(self, in_channel, k_size=3, n_class=2, stride=1):
        super().__init__()
        # Normalise stride to (h, w) so the shortcut can be sliced per axis;
        # nn.Conv2d accepts either form.
        if isinstance(stride, (tuple, list)):
            self._stride_hw = (int(stride[0]), int(stride[1]))
        else:
            self._stride_hw = (int(stride), int(stride))

        padding = self._same_padding(k_size)
        self.same_padding1 = nn.ReplicationPad2d(padding)
        self.conv1 = nn.Conv2d(
            in_channels=in_channel,
            out_channels=in_channel,
            kernel_size=k_size,
            stride=stride,
            bias=False,
        )
        self.adain1 = Adain(in_channel, n_class)
        self.same_padding2 = nn.ReplicationPad2d(padding)
        # Downsample only once. Striding conv2 as well shrank the main branch
        # by stride**2 and made ``h + res`` fail for every stride > 1.
        self.conv2 = nn.Conv2d(
            in_channels=in_channel,
            out_channels=in_channel,
            kernel_size=k_size,
            stride=1,
            bias=False,
        )
        self.adain2 = Adain(in_channel, n_class)

    @staticmethod
    def _same_padding(k_size):
        """Return ``(left, right, top, bottom)`` padding that keeps size at stride 1.

        A kernel of size ``k`` consumes ``k - 1`` extra pixels per axis. For odd
        ``k`` the padding is split evenly (matching the previous
        ``int((k - 1) / 2)`` on every side). For even ``k`` the extra pixel goes
        on the right/bottom instead of being dropped.
        """
        total = k_size - 1
        before = total // 2
        after = total - before
        return (before, after, before, after)

    def forward(self, input, condition):
        """Run both conditional conv stages and add the (possibly subsampled) skip.

        Args:
            input: Feature map with ``in_channel`` channels, in a layout accepted
                by ``nn.ReplicationPad2d`` / ``nn.Conv2d`` (e.g. ``(N, C, H, W)``).
            condition: Passed unchanged to both ``Adain`` layers. Its expected
                form is defined by ``components.Adain`` (not visible here).

        Returns:
            ``main_branch + shortcut``. At stride 1 this has ``input``'s spatial
            size; at stride ``s`` each spatial dim becomes ``ceil(dim / s)``.
            This assumes ``Adain`` preserves its input's shape; confirm.
        """
        stride_h, stride_w = self._stride_hw
        # Parameter-free identity shortcut. Subsampling with the conv's stride
        # yields exactly ceil(dim / stride) rows/cols, matching conv1's output.
        if (stride_h, stride_w) == (1, 1):
            res = input
        else:
            res = input[..., ::stride_h, ::stride_w]

        h = self.same_padding1(input)
        h = self.conv1(h)
        h = self.adain1(h, condition)
        h = self.same_padding2(h)
        h = self.conv2(h)
        h = self.adain2(h, condition)
        # NOTE(review): no nonlinearity between or after the stages; confirm
        # whether Adain applies one internally.
        return h + res