Skip to content

Commit

Permalink
Create vae_keras_celeba.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bojone authored Sep 1, 2018
1 parent 36d94b8 commit 08a51b7
Showing 1 changed file with 149 additions and 0 deletions.
149 changes: 149 additions & 0 deletions vae_keras_celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#! -*- coding: utf-8 -*-

import numpy as np
from scipy import misc
import glob
import imageio
from keras.models import Model
from keras.layers import *
from keras import backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback


# Collect every aligned CelebA image path and randomise the order once up
# front (the data generator reshuffles again on every pass).
imgs = glob.glob('img_align_celeba/*.jpg')
if not imgs:
    # Fail early with a readable message instead of an IndexError on imgs[0].
    raise RuntimeError(
        "No images matched 'img_align_celeba/*.jpg'; "
        "expected the dataset in ./img_align_celeba/"
    )
np.random.shuffle(imgs)

# The first image is used to read the (height, width) of the pictures;
# `center_height` is the row offset at which a `width`-tall crop is taken.
# NOTE(review): this presumes all images share the first image's size - confirm.
height, width = misc.imread(imgs[0]).shape[:2]
center_height = int((height - width) / 2)
img_dim = 64   # side length of the square images fed to the network
z_dim = 512    # latent size; also the widest convolution's filter count


def imread(f):
    """Load image file `f` as a float32 array ready for the network.

    The image is cropped to the rows
    [center_height, center_height + width), resized to
    (img_dim, img_dim) and rescaled with ``/ 255 * 2 - 1`` (so a pixel
    value of 0 maps to -1 and 255 maps to 1).

    NOTE(review): `scipy.misc.imread` / `imresize` are deprecated and
    missing from newer SciPy releases - confirm the pinned SciPy version.
    """
    raw = misc.imread(f)
    row_end = center_height + width
    cropped = raw[center_height:row_end, :]
    resized = misc.imresize(cropped, (img_dim, img_dim))
    as_float = resized.astype(np.float32)
    return as_float / 255 * 2 - 1


def data_generator(batch_size=32):
    """Yield ``(batch, None)`` tuples forever.

    The global `imgs` list is reshuffled in place at the start of every
    pass. Each image is loaded through `imread`; once `batch_size` images
    have accumulated they are stacked with `np.array` and yielded. A
    partially filled batch at the end of a pass is carried over into the
    next pass rather than dropped.
    """
    pending = []
    while True:
        np.random.shuffle(imgs)
        for filename in imgs:
            pending.append(imread(filename))
            if len(pending) != batch_size:
                continue
            # The second tuple element is None: no separate target is supplied.
            yield np.array(pending), None
            pending = []


# Encoder: five stride-2 5x5 convolutions, each followed by batch
# normalisation and LeakyReLU(0.2), then global average pooling.
x_in = Input(shape=(img_dim, img_dim, 3))
x = x_in
# Filter counts grow z_dim/16 -> z_dim. Floor division (`//`) is used so the
# counts stay integers on Python 3, where `/` would yield floats such as 32.0.
for divisor in (16, 8, 4, 2, 1):
    x = Conv2D(z_dim // divisor,
               kernel_size=(5, 5),
               strides=(2, 2),
               padding='SAME')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
x = GlobalAveragePooling2D()(x)

encoder = Model(x_in, x)
encoder.summary()
# Spatial dimensions of the feature map feeding the pooling layer
# (layers[-2] is the layer just before GlobalAveragePooling2D); the batch
# and channel axes are stripped by the [1:-1] slice.
map_size = K.int_shape(encoder.layers[-2].output)[1:-1]

# Decoder layers, i.e. the generator part.
# A dense layer expands the latent vector to a (map_size, z_dim) feature map,
# which four stride-2 transposed convolutions (batch norm + ReLU) and a final
# 3-channel transposed convolution with tanh turn into an image.
z_in = Input(shape=K.int_shape(x)[1:])
z = z_in
# int(...) turns the NumPy scalar from np.prod into a plain Python int.
z = Dense(int(np.prod(map_size)) * z_dim)(z)
z = Reshape(map_size + (z_dim,))(z)
# Floor division (`//`) keeps the filter counts integral on Python 3, where
# `/` would produce floats.
for divisor in (2, 4, 8, 16):
    z = Conv2DTranspose(z_dim // divisor,
                        kernel_size=(5, 5),
                        strides=(2, 2),
                        padding='SAME')(z)
    z = BatchNormalization()(z)
    z = Activation('relu')(z)
z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = Activation('tanh')(z)

decoder = Model(z_in, z)
decoder.summary()

class ScaleShift(Layer):
    """Scale-and-shift layer: ``exp(log_scale) * z + shift``.

    Expects a list of three tensors ``[z, shift, log_scale]``. Besides
    returning the transformed tensor it registers, through `add_loss`,
    the negated sum of the axis-0 mean of `log_scale`.
    """

    def __init__(self, **kwargs):
        super(ScaleShift, self).__init__(**kwargs)

    def call(self, inputs):
        noise, shift, log_scale = inputs
        transformed = K.exp(log_scale) * noise + shift
        # Mean over axis 0, summed over what remains, with the sign flipped.
        neg_log_det = -K.sum(K.mean(log_scale, 0))
        self.add_loss(neg_log_det)
        return transformed

# Two dense heads on the encoder output give the shift and log-scale used
# by ScaleShift.
z_shift = Dense(z_dim)(x)
z_log_scale = Dense(z_dim)(x)

# Standard-normal noise with the same (dynamic) shape as z_shift.
u = Lambda(lambda t: K.random_normal(shape=K.shape(t)))(z_shift)
z = ScaleShift()([u, z_shift, z_log_scale])

# The model output is the residual between the input and its reconstruction.
x_recon = decoder(z)
x_out = Subtract()([x_in, x_recon])

# Reconstruction term: half the squared residual (mean over axis 0, summed
# over the rest) plus a constant proportional to the number of output
# elements per sample.
n_out_elements = np.prod(K.int_shape(x_out)[1:])
recon_loss = (
    0.5 * K.sum(K.mean(x_out**2, 0))
    + 0.5 * np.log(2*np.pi) * n_out_elements
)
# Latent term: half the squared norm of z minus half the squared norm of u.
# ScaleShift contributes a further loss term on its own via add_loss.
z_loss = (
    0.5 * K.sum(K.mean(z**2, 0))
    - 0.5 * K.sum(K.mean(u**2, 0))
)
vae_loss = recon_loss + z_loss

# The loss is attached with add_loss, so compile() receives no loss argument.
vae = Model(x_in, x_out)
vae.add_loss(vae_loss)
vae.compile(optimizer=Adam(1e-4))


def sample(path):
    """Decode a 9x9 grid of random latent vectors and save it to `path`.

    Latent vectors are drawn from a standard normal with the encoder
    output's per-sample shape, decoded in a single batched `predict`
    call, and tiled row by row into one image.

    The decoder ends in tanh, so its output is mapped from [-1, 1] to
    [0, 255] and converted to uint8 explicitly before writing. Handing a
    float array to `imageio.imwrite` would leave the conversion (and any
    rescaling) to imageio.
    """
    n = 9
    figure = np.zeros((img_dim * n, img_dim * n, 3))
    # One batched forward pass instead of n*n single-sample calls.
    latents = np.random.randn(n * n, *K.int_shape(x)[1:])
    images = decoder.predict(latents)
    for index, image in enumerate(images):
        i, j = divmod(index, n)
        figure[i * img_dim: (i + 1) * img_dim,
               j * img_dim: (j + 1) * img_dim] = image
    figure = (figure + 1) / 2 * 255
    figure = np.clip(np.round(figure), 0, 255).astype(np.uint8)
    imageio.imwrite(path, figure)


class Evaluate(Callback):
    """Callback that writes sample images and tracks the best loss.

    After every epoch it saves a sample grid to
    ``samples/test_<epoch>.png``, records ``(epoch, loss)`` in
    `self.losses`, and, whenever the loss is at or below the lowest seen
    so far, saves the encoder weights to ``./best_encoder.weights``.
    """

    def __init__(self):
        import os
        # Initialise the base Callback so its own attributes are set up.
        super(Evaluate, self).__init__()
        self.lowest = 1e10
        self.losses = []
        if not os.path.exists('samples'):
            os.mkdir('samples')

    def on_epoch_end(self, epoch, logs=None):
        path = 'samples/test_%s.png' % epoch
        sample(path)
        # `logs` defaults to None; skip bookkeeping if no loss was reported
        # rather than failing on a subscript.
        loss = (logs or {}).get('loss')
        if loss is None:
            return
        self.losses.append((epoch, loss))
        if loss <= self.lowest:
            self.lowest = loss
            encoder.save_weights('./best_encoder.weights')


# Train from the endless batch generator: 1000 epochs of 1000 steps each,
# with the Evaluate callback running at the end of every epoch.
evaluator = Evaluate()
train_batches = data_generator()

vae.fit_generator(
    train_batches,
    steps_per_epoch=1000,
    epochs=1000,
    callbacks=[evaluator],
)

0 comments on commit 08a51b7

Please sign in to comment.