diff --git a/vae_keras.py b/vae_keras.py index befed63..7efab13 100644 --- a/vae_keras.py +++ b/vae_keras.py @@ -24,8 +24,7 @@ latent_dim = 2 # 隐变量取2维只是为了方便后面画图 intermediate_dim = 256 epochs = 50 -epsilon_std = 1.0 -num_classes = 10 + # 加载MNIST数据集 (x_train, y_train_), (x_test, y_test_) = mnist.load_data() @@ -45,8 +44,7 @@ # 重参数技巧 def sampling(args): z_mean, z_log_var = args - epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., - stddev=epsilon_std) + epsilon = K.random_normal(shape=K.shape(z_mean)) return z_mean + K.exp(z_log_var / 2) * epsilon # 重参数层,相当于给输入加入噪声