forked from optuna/optuna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tfkeras_integration.py
144 lines (100 loc) · 3.96 KB
/
tfkeras_integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Optuna example that demonstrates a pruner for tf.keras.
In this example, we optimize the validation accuracy of hand-written digit recognition
using tf.keras and MNIST, where the architecture of the neural network
and the parameters of optimizer are optimized.
Throughout the training of neural networks,
a pruner observes intermediate results and stops unpromising trials.
You can run this example as follows:
$ python tfkeras_integration.py
"""
import tensorflow as tf
import tensorflow_datasets as tfds
import optuna
from optuna.integration import TFKerasPruningCallback
# Training / evaluation configuration shared by the dataset builders and objective().
BATCHSIZE = 128  # examples per mini-batch
CLASSES = 10  # number of output units (one per digit class)
EPOCHS = 20  # upper bound on epochs per trial; early stopping may end training sooner
N_TRAIN_EXAMPLES = 3000
# Batches per epoch: N_TRAIN_EXAMPLES / BATCHSIZE, divided by a further 10
# (evaluates to 2 with the values above) -- presumably to keep each trial short.
STEPS_PER_EPOCH = N_TRAIN_EXAMPLES // (BATCHSIZE * 10)
VALIDATION_STEPS = 30  # validation batches evaluated at the end of each epoch
def _load_mnist(split, shuffle_files):
    """Load one MNIST split from TFDS as (image, label) pairs with rescaled pixels.

    Shared by train_dataset() and eval_dataset() so both pipelines always apply
    identical preprocessing.

    Args:
        split: TFDS split to load (e.g. ``tfds.Split.TRAIN``).
        shuffle_files: Whether TFDS should shuffle the order of the underlying files.

    Returns:
        A ``tf.data.Dataset`` of ``(float32 image / 255., label)`` tuples.
    """
    ds = tfds.load('mnist', split=split, shuffle_files=shuffle_files)
    # Turn the feature dict into the (input, target) tuple Keras expects and
    # rescale pixel values (presumably 0-255 integers) into [0, 1]. The map runs
    # in parallel; tf.data keeps element order deterministic by default.
    return ds.map(
        lambda x: (tf.cast(x['image'], tf.float32) / 255., x['label']),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )


def train_dataset():
    """Return the MNIST training pipeline: infinite, shuffled, batched and prefetched."""
    ds = _load_mnist(tfds.Split.TRAIN, shuffle_files=True)
    # repeat() makes the dataset infinite, so the caller must bound each epoch
    # (objective() does so with steps_per_epoch).
    ds = ds.repeat().shuffle(1024).batch(BATCHSIZE)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)


def eval_dataset():
    """Return the MNIST test pipeline: infinite, unshuffled, batched and prefetched."""
    ds = _load_mnist(tfds.Split.TEST, shuffle_files=False)
    # Infinite as well; the caller caps evaluation length (validation_steps).
    ds = ds.repeat().batch(BATCHSIZE)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
def create_model(trial):
    """Build and compile a one-hidden-layer network whose hyperparameters come from ``trial``.

    Hyperparameters suggested through Optuna:
      * ``lr``: SGD learning rate, log-uniform in [1e-4, 1e-1].
      * ``momentum``: SGD (Nesterov) momentum, uniform in [0.0, 1.0].
      * ``units``: width of the hidden Dense layer, one of 32/64/128/256/512.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        A compiled ``tf.keras.Sequential`` model with a CLASSES-way softmax output.
    """
    # suggest_float(..., log=True) / suggest_float(...) are the non-deprecated
    # equivalents of suggest_loguniform / suggest_uniform; names and ranges are unchanged.
    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    momentum = trial.suggest_float('momentum', 0.0, 1.0)
    units = trial.suggest_categorical('units', [32, 64, 128, 256, 512])

    # Compose neural network with one hidden layer.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(units=units, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(CLASSES, activation=tf.nn.softmax))

    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated and rejected outright by Keras 3.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum, nesterov=True),
        # The sparse loss takes integer class labels rather than one-hot vectors.
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
def objective(trial):
    """Train one sampled model configuration and return its validation accuracy.

    Args:
        trial: Optuna trial that supplies hyperparameters (via create_model) and
            receives intermediate values through TFKerasPruningCallback.

    Returns:
        The monitored validation accuracy of the last epoch that ran. Training can
        end before EPOCHS because of early stopping. A pruned trial presumably
        aborts via an exception raised by the pruning callback rather than
        returning.
    """
    # Clear clutter from previous TensorFlow graphs so trials do not accumulate state.
    tf.keras.backend.clear_session()

    # The accuracy history key is 'val_accuracy' on TF 2.x and 'val_acc' on TF 1.x.
    # Compare the numeric major version: a plain string comparison would
    # misclassify e.g. '10.0' as older than '2'.
    tf_major = int(tf.__version__.split('.')[0])
    monitor = 'val_accuracy' if tf_major >= 2 else 'val_acc'

    model = create_model(trial)
    ds_train = train_dataset()
    ds_eval = eval_dataset()

    # EarlyStopping (Keras default monitor: 'val_loss') stops after 3 epochs without
    # improvement. The pruning callback reports `monitor` to Optuna so the
    # pruner can stop unpromising trials.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=3),
        TFKerasPruningCallback(trial, monitor),
    ]

    # The datasets repeat indefinitely, so epoch and validation lengths are set
    # explicitly by steps_per_epoch / validation_steps.
    history = model.fit(
        ds_train,
        epochs=EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=ds_eval,
        validation_steps=VALIDATION_STEPS,
        callbacks=callbacks,
    )

    # TODO(@sfujiwara): Investigate why the logger here is called twice.
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    # tf.compat.v1.logging.info('hello optuna')

    # Score the trial by the accuracy of the final epoch that ran.
    return history.history[monitor][-1]
def show_result(study):
    """Print trial-state counts for ``study`` and the value/params of its best trial.

    Args:
        study: The Optuna study to summarize, typically after ``optimize`` returns.
    """
    # optuna.trial.TrialState replaces optuna.structs.TrialState, which was
    # deprecated in Optuna v1.4 and removed from later releases.
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    print('Study statistics: ')
    print(' Number of finished trials: ', len(study.trials))
    print(' Number of pruned trials: ', len(pruned_trials))
    print(' Number of complete trials: ', len(complete_trials))

    # study.best_trial raises ValueError when no trial has completed (e.g. every
    # trial was pruned or failed before the timeout); report that instead of crashing.
    if not complete_trials:
        print('No trial completed; there is no best trial to report.')
        return

    print('Best trial:')
    trial = study.best_trial
    print(' Value: ', trial.value)
    print(' Params: ')
    for key, value in trial.params.items():
        print(' {}: {}'.format(key, value))
def main():
    """Run the pruned hyperparameter search and print a summary of the study."""
    # Median pruning, with pruning disabled until two trials have finished so the
    # pruner has a baseline to compare intermediate values against.
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=2)
    # objective() returns an accuracy, so larger is better.
    study = optuna.create_study(direction='maximize', pruner=median_pruner)
    # Stop after 25 trials or 600 seconds, whichever comes first.
    study.optimize(objective, n_trials=25, timeout=600)
    show_result(study)


if __name__ == '__main__':
    main()