diff --git a/model.py b/model.py index 55ccc0b..f175f5c 100644 --- a/model.py +++ b/model.py @@ -22,8 +22,8 @@ tf.logging.set_verbosity(tf.logging.INFO) # Training Parameters -learning_rate_old = 1e-10 -num_steps = 9000 +start_lr = 1e-10 +training_steps = 12000 batch_size = 128 @@ -95,6 +95,7 @@ def model_fn(features, labels, mode): logits=logits_train, labels=tf.cast(labels, dtype=tf.int32))) train_ops = None + acc_op = None # -------------------------------------------------------------------------- # Optimize @@ -102,7 +103,7 @@ def model_fn(features, labels, mode): if mode == tf.estimator.ModeKeys.TRAIN: global_step = tf.train.get_global_step() learning_rate = tf.train.exponential_decay( - learning_rate_old, global_step=global_step, + start_lr, global_step=global_step, decay_steps=100, decay_rate=1.30) optimizer = tf.train.AdamOptimizer(learning_rate) train_ops = optimizer.minimize(loss, global_step=global_step) @@ -110,11 +111,12 @@ def model_fn(features, labels, mode): tf.summary.scalar("current_step", global_step) tf.summary.scalar("loss", loss) - - # Evaluate the accuracy of the model - acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes) - + # Evaluate the accuracy of the model + if mode == tf.estimator.ModeKeys.EVAL: + acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes) + return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops={'accuracy': acc_op}) + # TF Estimators requires to return a EstimatorSpec, that specify # the different ops for training, evaluating, ... @@ -122,8 +124,7 @@ def model_fn(features, labels, mode): mode=mode, predictions=pred_classes, loss=loss, - train_op=train_ops, - eval_metric_ops={'accuracy': acc_op}) + train_op=train_ops) return estim_specs @@ -136,7 +137,7 @@ def model_fn(features, labels, mode): batch_size=batch_size, num_epochs=None, shuffle=True) # Train the Model -model.train(input_fn, steps=num_steps) +model.train(input_fn, steps=training_steps) # Evaluate the Model # Define the input function for evaluating @@ -146,4 +147,6 @@ def model_fn(features, labels, mode): # Use the Estimator 'evaluate' method e = model.evaluate(input_fn) +print("The test accuracy of the network: ", e['accuracy']) +