Skip to content

Commit 2f9eaa3

Browse files
committed
fix compatibility of tensorflow
1 parent 8cce5b8 commit 2f9eaa3

4 files changed

+22
-4
lines changed

kerasTUT/6-CNN_example.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
# 6 - CNN example
1212

13+
# to try tensorflow, un-comment following two lines
14+
# import os
15+
# os.environ['KERAS_BACKEND']='tensorflow'
16+
1317
import numpy as np
1418
np.random.seed(1337) # for reproducibility
1519
from keras.datasets import mnist
@@ -23,8 +27,8 @@
2327
(X_train, y_train), (X_test, y_test) = mnist.load_data()
2428

2529
# data pre-processing
26-
X_train = X_train.reshape(-1, 1, 28, 28)
27-
X_test = X_test.reshape(-1, 1, 28, 28)
30+
X_train = X_train.reshape(-1, 1,28, 28)
31+
X_test = X_test.reshape(-1, 1,28, 28)
2832
y_train = np_utils.to_categorical(y_train, nb_classes=10)
2933
y_test = np_utils.to_categorical(y_test, nb_classes=10)
3034

@@ -37,8 +41,9 @@
3741
nb_row=5,
3842
nb_col=5,
3943
border_mode='same', # Padding method
44+
dim_ordering='th', # if use tensorflow, to set the input dimension order to theano ("th") style, but you can change it.
4045
input_shape=(1, # channels
41-
28, 28) # height & width
46+
28, 28,) # height & width
4247
))
4348
model.add(Activation('relu'))
4449

kerasTUT/7-RNN_Classifier_example.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
# 8 - RNN Classifier example
1212

13+
# to try tensorflow, un-comment following two lines
14+
# import os
15+
# os.environ['KERAS_BACKEND']='tensorflow'
16+
1317
import numpy as np
1418
np.random.seed(1337) # for reproducibility
1519

@@ -43,8 +47,11 @@
4347

4448
# RNN cell
4549
model.add(SimpleRNN(
46-
batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_SIZE), # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
50+
# for batch_input_shape, if using tensorflow as the backend, we have to put None for the batch_size.
51+
# Otherwise, model.evaluate() will get error.
52+
batch_input_shape=(None, TIME_STEPS, INPUT_SIZE), # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
4753
output_dim=CELL_SIZE,
54+
unroll=True,
4855
))
4956

5057
# output layer

kerasTUT/8-RNN_LSTM_Regressor_example.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
# 8 - RNN LSTM Regressor example
1212

13+
# to try tensorflow, un-comment following two lines
14+
# import os
15+
# os.environ['KERAS_BACKEND']='tensorflow'
1316
import numpy as np
1417
np.random.seed(1337) # for reproducibility
1518
import matplotlib.pyplot as plt

kerasTUT/9-Autoencoder_example.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
# 9 - Autoencoder example
1212

13+
# to try tensorflow, un-comment following two lines
14+
# import os
15+
# os.environ['KERAS_BACKEND']='tensorflow'
1316
import numpy as np
1417
np.random.seed(1337) # for reproducibility
1518

0 commit comments

Comments
 (0)