diff --git a/dual_net.py b/dual_net.py index 7a405e584..731103320 100644 --- a/dual_net.py +++ b/dual_net.py @@ -239,7 +239,6 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos # pi_tensor is one_hot when generated from sgfs (for supervised learning) # and soft-max when using self-play records. argmax normalizes the two. policy_target_top_1 = tf.argmax(pi_tensor, axis=1) - policy_output_top_1 = tf.argmax(policy_output, axis=1) policy_output_in_top1 = tf.to_float( tf.nn.in_top_k(policy_output, policy_target_top_1, k=1)) @@ -251,7 +250,6 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos policy_output, tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1])) - # TODO(sethtroisi): For V10 add tf.variable_scope for tf.metrics.mean's with tf.variable_scope("metrics"): metric_ops = { 'policy_cost': tf.metrics.mean(policy_cost), @@ -485,6 +483,8 @@ def count_examples(tf_record): total_examples = sum(map(count_examples, tf_records)) steps = total_examples // FLAGS.train_batch_size + if FLAGS.use_tpu: + steps //= FLAGS.num_tpu_cores if FLAGS.use_tpu: def input_fn(params): @@ -494,7 +494,6 @@ def input_fn(params): random_rotation=True) # TODO: get hooks working again with TPUestimator. hooks = [] - steps //= FLAGS.num_tpu_cores else: def input_fn(): return preprocessing.get_input_tensors(