forked from bojone/vae
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
'''用Keras实现的CVAE | ||
目前只保证支持Tensorflow后端 | ||
#来自 | ||
https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py | ||
''' | ||
|
||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from scipy.stats import norm | ||
|
||
from keras.layers import Input, Dense, Lambda | ||
from keras.models import Model | ||
from keras import backend as K | ||
from keras import metrics | ||
from keras.datasets import mnist | ||
from keras.utils import to_categorical | ||
|
||
batch_size = 100 | ||
original_dim = 784 | ||
latent_dim = 2 # 隐变量取2维只是为了方便后面画图 | ||
intermediate_dim = 256 | ||
epochs = 100 | ||
epsilon_std = 1.0 | ||
num_classes = 10 | ||
|
||
# 加载MNIST数据集 | ||
(x_train, y_train), (x_test_, y_test_) = mnist.load_data() | ||
x_train = x_train.astype('float32') / 255. | ||
x_test = x_test.astype('float32') / 255. | ||
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))) | ||
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))) | ||
y_train = keras.utils.to_categorical(y_train, num_classes) | ||
y_test = keras.utils.to_categorical(y_test, num_classes) | ||
|
||
|
||
x = Input(shape=(original_dim,)) | ||
h = Dense(intermediate_dim, activation='relu')(x) | ||
|
||
# 算p(Z|X)的均值和方差 | ||
z_mean = Dense(latent_dim)(h) | ||
z_log_var = Dense(latent_dim)(h) | ||
|
||
y = Input(shape=(num_classes,)) # 输入类别 | ||
yh = Dense(latent_dim)(y) # 这里就是直接构建每个类别的均值 | ||
|
||
# 重参数技巧 | ||
def sampling(args): | ||
z_mean, z_log_var = args | ||
epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., | ||
stddev=epsilon_std) | ||
return z_mean + K.exp(z_log_var / 2) * epsilon | ||
|
||
# 重参数层,相当于给输入加入噪声 | ||
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var]) | ||
|
||
# 解码层,也就是生成器部分 | ||
decoder_h = Dense(intermediate_dim, activation='relu') | ||
decoder_mean = Dense(original_dim, activation='sigmoid') | ||
h_decoded = decoder_h(z) | ||
x_decoded_mean = decoder_mean(h_decoded) | ||
|
||
# 建立模型 | ||
vae = Model([x, y], [x_decoded_mean, yh]) | ||
|
||
# xent_loss是重构loss,kl_loss是KL loss | ||
xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean) | ||
|
||
# 只需要修改K.square(z_mean)为K.square(z_mean - yh),也就是让隐变量向类内均值看齐 | ||
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean - yh) - K.exp(z_log_var), axis=-1) | ||
vae_loss = K.mean(xent_loss + kl_loss) | ||
|
||
# add_loss是新增的方法,用于更灵活地添加各种loss | ||
vae.add_loss(vae_loss) | ||
vae.compile(optimizer='rmsprop') | ||
vae.summary() | ||
|
||
vae.fit([x_train, y_train], | ||
shuffle=True, | ||
epochs=epochs, | ||
batch_size=batch_size, | ||
validation_data=([x_test, y_test], None)) | ||
|
||
|
||
# 构建encoder,然后观察各个数字在隐空间的分布 | ||
encoder = Model(x, z_mean) | ||
|
||
x_test_encoded = encoder.predict(x_test, batch_size=batch_size) | ||
plt.figure(figsize=(6, 6)) | ||
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test_) | ||
plt.colorbar() | ||
plt.show() | ||
|
||
# 构建生成器 | ||
decoder_input = Input(shape=(latent_dim,)) | ||
_h_decoded = decoder_h(decoder_input) | ||
_x_decoded_mean = decoder_mean(_h_decoded) | ||
generator = Model(decoder_input, _x_decoded_mean) | ||
|
||
# 输出每个类的均值向量 | ||
mu = Model(y, yh) | ||
mu = mu.predict(np.eye(num_classes)) | ||
|
||
# 观察能否通过控制隐变量的均值来输出特定类别的数字 | ||
n = 15 # figure with 15x15 digits | ||
digit_size = 28 | ||
figure = np.zeros((digit_size * n, digit_size * n)) | ||
|
||
output_digit = 9 | ||
|
||
#用正态分布的分位数来构建隐变量对 | ||
grid_x = norm.ppf(np.linspace(0.05, 0.95, n)) + mu[output_digit][1] | ||
grid_y = norm.ppf(np.linspace(0.05, 0.95, n)) + mu[output_digit][0] | ||
|
||
for i, yi in enumerate(grid_x): | ||
for j, xi in enumerate(grid_y): | ||
z_sample = np.array([[xi, yi]]) | ||
x_decoded = generator.predict(z_sample) | ||
digit = x_decoded[0].reshape(digit_size, digit_size) | ||
figure[i * digit_size: (i + 1) * digit_size, | ||
j * digit_size: (j + 1) * digit_size] = digit | ||
|
||
plt.figure(figsize=(10, 10)) | ||
plt.imshow(figure, cmap='Greys_r') | ||
plt.show() |