forked from charissar/ml-class
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import os | ||
os.environ["KERAS_BACKEND"] = "tensorflow" | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
|
||
from keras.layers import Input | ||
from keras.models import Model, Sequential | ||
from keras.layers.core import Reshape, Dense, Dropout, Flatten | ||
from keras.layers.advanced_activations import LeakyReLU | ||
from keras.layers.convolutional import Convolution2D, UpSampling2D | ||
from keras.layers.normalization import BatchNormalization | ||
from keras.datasets import mnist | ||
from keras.optimizers import Adam | ||
from keras import backend as K | ||
from keras import initializers | ||
from PIL import Image | ||
from keras.callbacks import LambdaCallback | ||
|
||
import wandb | ||
|
||
run = wandb.init() | ||
|
||
config = wandb.config | ||
|
||
|
||
# The results are a little better when the dimensionality of the random vector is only 10. | ||
# The dimensionality has been left at 100 for consistency with other GAN implementations. | ||
randomDim = 10 | ||
|
||
# Load MNIST data | ||
(X_train, y_train), (X_test, y_test) = mnist.load_data() | ||
X_train = (X_train.astype(np.float32) - 127.5)/127.5 | ||
X_train = X_train.reshape(60000, 784) | ||
|
||
config.lr=0.0002 | ||
config.beta_1=0.5 | ||
config.batch_size=128 | ||
config.epochs=10 | ||
|
||
# Optimizer | ||
adam = Adam(config.lr, beta_1=config.beta_1) | ||
|
||
generator = Sequential() | ||
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) | ||
generator.add(LeakyReLU(0.2)) | ||
generator.add(Dense(512)) | ||
generator.add(LeakyReLU(0.2)) | ||
generator.add(Dense(1024)) | ||
generator.add(LeakyReLU(0.2)) | ||
generator.add(Dense(784, activation='tanh')) | ||
generator.compile(loss='binary_crossentropy', optimizer=adam, metrics=['acc']) | ||
|
||
discriminator = Sequential() | ||
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02))) | ||
discriminator.add(LeakyReLU(0.2)) | ||
discriminator.add(Dropout(0.3)) | ||
discriminator.add(Dense(512)) | ||
discriminator.add(LeakyReLU(0.2)) | ||
discriminator.add(Dropout(0.3)) | ||
discriminator.add(Dense(256)) | ||
discriminator.add(LeakyReLU(0.2)) | ||
discriminator.add(Dropout(0.3)) | ||
discriminator.add(Dense(1, activation='sigmoid')) | ||
discriminator.compile(loss='binary_crossentropy', optimizer=adam, metrics=['acc']) | ||
|
||
# Combined network | ||
discriminator.trainable = False | ||
ganInput = Input(shape=(randomDim,)) | ||
x = generator(ganInput) | ||
ganOutput = discriminator(x) | ||
gan = Model(inputs=ganInput, outputs=ganOutput) | ||
gan.compile(loss='binary_crossentropy', optimizer=adam, metrics = ['acc']) | ||
|
||
dLosses = [] | ||
gLosses = [] | ||
iter = 0 | ||
|
||
# Write out generated MNIST images | ||
def writeGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)): | ||
noise = np.random.normal(0, 1, size=[examples, randomDim]) | ||
generatedImages = generator.predict(noise) | ||
generatedImages = generatedImages.reshape(examples, 28, 28) | ||
|
||
|
||
for i in range(10): | ||
img = Image.fromarray((generatedImages[0] + 1.)* (255/2.)) | ||
img = img.convert('RGB') | ||
img.save(str(i) + ".jpg") | ||
|
||
|
||
# Save the generator and discriminator networks (and weights) for later use | ||
def saveModels(epoch): | ||
generator.save('models/gan_generator_epoch_%d.h5' % epoch) | ||
discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch) | ||
|
||
|
||
def log_generator(epoch, logs): | ||
global iter | ||
iter += 1 | ||
if iter % 50 == 0: | ||
run.history.add({'generator_loss': logs['loss'], | ||
'generator_acc': logs['acc'], | ||
'discriminator_loss': 0.0, | ||
'discriminator_acc': (1-logs['acc'])}) | ||
|
||
def log_discriminator(epoch, logs): | ||
global iter | ||
if iter% 50 == 25: | ||
run.history.add({ | ||
'generator_loss': 0.0, | ||
'generator_acc': (logs['acc']), | ||
'discriminator_loss': logs['loss'], | ||
'discriminator_acc': logs['acc']}) | ||
|
||
def train(epochs=config.epochs, batchSize=config.batch_size): | ||
batchCount = int(X_train.shape[0] / config.batch_size) | ||
print('Epochs:', epochs) | ||
print('Batch size:', batchSize) | ||
print('Batches per epoch:', batchCount) | ||
|
||
wandb_logging_callback_d = LambdaCallback(on_epoch_end=log_discriminator) | ||
wandb_logging_callback_g = LambdaCallback(on_epoch_end=log_generator) | ||
|
||
|
||
for e in range(1, epochs+1): | ||
|
||
for i in range(batchCount): | ||
# Get a random set of input noise and images | ||
noise = np.random.normal(0, 1, size=[batchSize, randomDim]) | ||
imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)] | ||
|
||
# Generate fake MNIST images | ||
generatedImages = generator.predict(noise) | ||
# print np.shape(imageBatch), np.shape(generatedImages) | ||
X = np.concatenate([imageBatch, generatedImages]) | ||
|
||
# Labels for generated and real data | ||
yDis = np.zeros(2*batchSize) | ||
# One-sided label smoothing | ||
yDis[:batchSize] = 0.9 | ||
|
||
# Train discriminator | ||
discriminator.trainable = True | ||
dloss = discriminator.fit(X, yDis, callbacks = [wandb_logging_callback_d]) | ||
|
||
# Train generator | ||
noise = np.random.normal(0, 1, size=[batchSize, randomDim]) | ||
yGen = np.ones(batchSize) | ||
discriminator.trainable = False | ||
gloss = gan.fit(noise, yGen, callbacks = [wandb_logging_callback_g]) | ||
|
||
writeGeneratedImages(i) | ||
|
||
# Store loss of most recent batch from this epoch | ||
dLosses.append(dloss) | ||
gLosses.append(gloss) | ||
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__': | ||
train(200, 128) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[default] | ||
entity: ml-class | ||
project: gan | ||
base_url: https://api.wandb.ai |