Skip to content

Commit

Permalink
bug fix in models
Browse files Browse the repository at this point in the history
  • Loading branch information
hwnam831 committed Jan 17, 2023
1 parent 805d619 commit f24c670
Showing 1 changed file with 77 additions and 1 deletion.
78 changes: 77 additions & 1 deletion Models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, c_in, c_mid, dilation=1):

def forward(self, x):
    """Apply the conv block and add the residual (skip) connection.

    The addition is deliberately out-of-place (`out + x`, not `out += x`):
    in-place ops on tensors that autograd still needs can corrupt the
    backward pass. The paste contained both the old in-place line and its
    replacement; only the out-of-place form is kept.
    """
    out = self.block(x)
    out = out + x
    return out

class CNNModel(nn.Module):
Expand Down Expand Up @@ -412,6 +412,44 @@ def forward(self, x, distill=False):
else:
return torch.relu(out+noise)

class NoiseInjector(nn.Module):
    """Add Gaussian noise with a learnable scale to the input.

    The noise standard deviation starts at ``noiselevel`` and is a
    trainable scalar parameter. NOTE(review): there is no
    ``self.training`` check, so noise is injected in eval mode as well —
    confirm that is intentional.
    """

    def __init__(self, noiselevel=0.1):
        super().__init__()
        # Single learnable scalar controlling the noise magnitude.
        self.noiselevel = nn.Parameter(torch.ones(1) * noiselevel)

    def forward(self, x):
        # Draw noise matching x's shape/dtype/device, scale, and add.
        noise = torch.randn_like(x)
        return x + noise * self.noiselevel

class RNNGenerator3(nn.Module):
    """GRU sequence generator with a residual connection.

    Projects each channel's length-``window`` slice to ``dim`` features,
    runs a single-layer GRU over the sequence, adds the GRU output back
    onto the encoding (residual), and decodes one value per time step.

    NOTE(review): ``threshold`` is accepted but never read, and
    ``self.scale`` is a learnable parameter that ``forward`` does not
    use — presumably kept for interface compatibility; confirm.
    """

    def __init__(self, threshold, scale=1, dim=128, window=32, drop=0.2):
        super().__init__()
        # window -> dim projection, ReLU, then trainable Gaussian noise.
        self.encoder = nn.Sequential(
            nn.Linear(window, dim),
            nn.ReLU(),
            NoiseInjector(drop),
        )
        # Residual branch: one-layer GRU over (N, S, C) sequences.
        self.resblock = nn.GRU(dim, dim, num_layers=1, batch_first=True)
        # Noise, dim -> 1 projection, ReLU output.
        self.decoder = nn.Sequential(
            NoiseInjector(drop),
            nn.Linear(dim, 1),
            nn.ReLU(),
        )
        self.scale = nn.Parameter(torch.ones(1) * scale)

    def forward(self, x, distill=False):
        """x: (N, C, S) -> (N, S); with ``distill`` also returns
        the intermediate (encoded, res, out) tensors."""
        seq = x.permute(0, 2, 1)  # N,C,S -> N,S,C
        encoded = self.encoder(seq)
        res, _ = self.resblock(encoded)
        combined = encoded + res  # residual sum, still N,S,C
        out = self.decoder(combined).view(combined.size(0), -1)
        if distill:
            return out, (encoded, res, out)
        return out

class MLPGen(nn.Module):
def __init__(self, threshold, scale=1, dim=128, window=32, drop=0.2, depth=2):
super().__init__()
Expand Down Expand Up @@ -547,6 +585,44 @@ def forward(self, x, distill=False):
else:
return torch.relu(out+noise)

class QGRU3(nn.Module):
    """GRUCell generator that unrolls the recurrence by hand.

    Mirrors ``RNNGenerator3`` but steps an ``nn.GRUCell`` over the
    time-major sequence manually instead of using ``nn.GRU``.

    NOTE(review): the NoiseInjector layers here receive ``scale`` rather
    than ``drop`` (RNNGenerator3 passes ``drop``), and ``threshold`` /
    ``drop`` are never read — verify this asymmetry is intentional.
    """

    def __init__(self, threshold, scale=1, dim=128, window=32, drop=0.2):
        super().__init__()
        # window -> dim projection followed by trainable Gaussian noise.
        self.encoder = nn.Sequential(
            nn.Linear(window, dim),
            NoiseInjector(scale),
        )
        # Single GRU cell; the forward pass unrolls it over time.
        self.gru1 = nn.GRUCell(dim, dim)
        # Noise, dim -> 1 projection, ReLU output.
        self.decoder = nn.Sequential(
            NoiseInjector(scale),
            nn.Linear(dim, 1),
            nn.ReLU(),
        )

    def forward(self, x, distill=False):
        """x: (N, C, S) -> (N, S); with ``distill`` also returns the
        batch-major (encoded, res, out) intermediates."""
        src = x.permute(2, 0, 1)  # N,C,S -> S,N,C (time-major)
        encoded = self.encoder(src)
        # Manually unroll the GRUCell along the time axis.
        state = torch.zeros_like(encoded[0])
        steps = []
        for frame in encoded:
            state = self.gru1(frame, state)
            steps.append(state)
        res = torch.stack(steps)
        combined = encoded + res  # residual sum, still S,N,C
        out = self.decoder(combined).view(combined.size(0), -1)
        out = out.permute(1, 0)  # S,N -> N,S
        if distill:
            return out, (encoded.permute(1, 0, 2), res.permute(1, 0, 2), out)
        return out

class Distiller(nn.Module):
def __init__(self, threshold, tdim=256, sdim=32, lamb_d = 0.1, lamb_r = 0.1, window=32):
Expand Down

0 comments on commit f24c670

Please sign in to comment.