Skip to content

Commit

Permalink
update SPADE block
Browse files Browse the repository at this point in the history
Modify SPADE architecture to be consistent with the official implementation.
  • Loading branch information
shaoanlu committed May 10, 2019
1 parent c563edc commit b2351de
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions networks/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def calc_loss(pred, target, loss='l2'):

def cyclic_loss(netG1, netG2, real1):
fake2 = netG2(real1)[-1] # fake2 ABGR
fake2_alpha = Lambda(lambda x: x[:,:,:, :1])(fake2) # fake2 BGR
fake2_alpha = Lambda(lambda x: x[:,:,:, :1])(fake2)
fake2 = Lambda(lambda x: x[:,:,:, 1:])(fake2) # fake2 BGR
cyclic1 = netG1(fake2)[-1] # cyclic1 ABGR
cyclic1_alpha = Lambda(lambda x: x[:,:,:, :1])(cyclic1) # cyclic1 BGR
cyclic1_alpha = Lambda(lambda x: x[:,:,:, :1])(cyclic1)
cyclic1 = Lambda(lambda x: x[:,:,:, 1:])(cyclic1) # cyclic1 BGR
loss = calc_loss(cyclic1, real1, loss='l1')
loss += 0.1 * calc_loss(cyclic1_alpha, fake2_alpha, loss='l1')
Expand Down
2 changes: 1 addition & 1 deletion networks/nn_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def SPADE(input_tensor, cond_input_tensor, f, use_norm=True, norm='none'):
kernel_initializer=conv_init, padding='same')(y)
beta = Conv2D(f, kernel_size=3, kernel_regularizer=regularizers.l2(w_l2),
kernel_initializer=conv_init, padding='same')(y)
x = multiply([x, gamma])
x = add([x, multiply([x, gamma])])
x = add([x, beta])
return x

Expand Down

0 comments on commit b2351de

Please sign in to comment.