
Commit 4d4ce89
Don't divide steps by NUM_TPUS if passed from command line (tensorflo…
sethtroisi authored Aug 3, 2018
1 parent d7b3486 commit 4d4ce89
Showing 1 changed file with 2 additions and 3 deletions.
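The gist of the change: steps is derived from the number of examples in the input records, and the per-core division now happens right where that derived value is computed, instead of later in the TPU branch where it would also hit a step count supplied on the command line. A hypothetical sketch of the fixed flow (the command-line override itself is implied by the commit title but not shown in this diff):

    # Hypothetical sketch; steps_flag stands in for a command-line override
    # that this diff implies but does not show.
    def compute_steps(total_examples, train_batch_size, use_tpu, num_tpu_cores,
                      steps_flag=None):
        if steps_flag is not None:
            return steps_flag           # user-supplied: leave it untouched
        steps = total_examples // train_batch_size
        if use_tpu:
            steps //= num_tpu_cores     # only the derived count is divided
        return steps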
dual_net.py: 5 changes (2 additions & 3 deletions)
@@ -239,7 +239,6 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
     # pi_tensor is one_hot when generated from sgfs (for supervised learning)
     # and soft-max when using self-play records. argmax normalizes the two.
     policy_target_top_1 = tf.argmax(pi_tensor, axis=1)
-    policy_output_top_1 = tf.argmax(policy_output, axis=1)
 
     policy_output_in_top1 = tf.to_float(
         tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
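A toy check of why argmax "normalizes the two" target formats (TF1-style APIs matching this file; the tensor values are invented for illustration):

    import tensorflow as tf

    pi_one_hot = tf.constant([[0., 1., 0.]])   # sgf-derived target (one-hot)
    pi_softmax = tf.constant([[.1, .7, .2]])   # self-play target (soft-max)
    policy_out = tf.constant([[.2, .5, .3]])   # network policy output

    with tf.Session() as sess:
        # Both target formats reduce to the same class index.
        print(sess.run(tf.argmax(pi_one_hot, axis=1)))   # [1]
        print(sess.run(tf.argmax(pi_softmax, axis=1)))   # [1]
        # in_top_k yields 1.0 when the network's top move matches the target.
        print(sess.run(tf.to_float(tf.nn.in_top_k(
            policy_out, tf.argmax(pi_one_hot, axis=1), k=1))))   # [1.]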
@@ -251,7 +250,6 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
         policy_output,
         tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))
 
-    # TODO(sethtroisi): For V10 add tf.variable_scope for tf.metrics.mean's
     with tf.variable_scope("metrics"):
         metric_ops = {
             'policy_cost': tf.metrics.mean(policy_cost),
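For context on the removed TODO (now resolved by the variable_scope wrapper above): tf.metrics.mean creates hidden local accumulator variables, and the scope namespaces them. A minimal sketch of that TF1 mechanic, not minigo-specific:

    import tensorflow as tf

    policy_cost = tf.constant(0.5)   # stand-in value
    with tf.variable_scope("metrics"):
        mean, update_op = tf.metrics.mean(policy_cost)

    # The metric's accumulators are namespaced by the scope,
    # e.g. "metrics/mean/total:0" and "metrics/mean/count:0".
    for v in tf.local_variables():
        print(v.name)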
@@ -485,6 +483,8 @@ def count_examples(tf_record):
 
     total_examples = sum(map(count_examples, tf_records))
     steps = total_examples // FLAGS.train_batch_size
+    if FLAGS.use_tpu:
+        steps //= FLAGS.num_tpu_cores
 
     if FLAGS.use_tpu:
         def input_fn(params):
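A worked example of the moved division, with made-up numbers (the division implies each of the num_tpu_cores cores consumes one train_batch_size batch per global step; that reading is inferred from this diff, not stated in the file):

    total_examples = 2048000
    train_batch_size = 256
    num_tpu_cores = 8

    steps = total_examples // train_batch_size   # 8000 single-core batches
    steps //= num_tpu_cores                      # 1000 global TPU steps
    print(steps)                                 # 1000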
@@ -494,7 +494,6 @@ def input_fn(params):
                 random_rotation=True)
         # TODO: get hooks working again with TPUestimator.
         hooks = []
-        steps //= FLAGS.num_tpu_cores
     else:
         def input_fn():
             return preprocessing.get_input_tensors(
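A design note on the two input_fn shapes above: TPUEstimator calls input_fn with a params dict carrying the per-core batch size, while a plain Estimator input_fn takes no arguments, which is why only the TPU branch defines input_fn(params). A generic sketch of that contract (standard TPUEstimator behavior; the dataset here is a stand-in, not minigo's pipeline):

    import tensorflow as tf

    def input_fn(params):
        # TPUEstimator injects the per-core batch size it derives from
        # its train_batch_size setting and the number of shards.
        batch_size = params['batch_size']
        features = tf.zeros([1024, 8])   # stand-in features
        dataset = tf.data.Dataset.from_tensor_slices(features)
        # TPUs need fixed shapes, hence drop_remainder=True.
        return dataset.repeat().batch(batch_size, drop_remainder=True)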
