Skip to content

Commit

Permalink
Update cvae_keras.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bojone authored Jun 10, 2019
1 parent ccec3d5 commit 9afaa51
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cvae_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist
from keras.utils import to_categorical


batch_size = 100
original_dim = 784
latent_dim = 2 # 隐变量取2维只是为了方便后面画图
Expand Down Expand Up @@ -69,7 +69,7 @@ def sampling(args):
vae = Model([x, y], [x_decoded_mean, yh])

# xent_loss是重构loss,kl_loss是KL loss
xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
xent_loss = K.sum(K.binary_crossentropy(x, x_decoded_mean), axis=-1)

# 只需要修改K.square(z_mean)为K.square(z_mean - yh),也就是让隐变量向类内均值看齐
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean - yh) - K.exp(z_log_var), axis=-1)
Expand Down

0 comments on commit 9afaa51

Please sign in to comment.