Skip to content

Commit

Permalink
Flake8 fixes (tensorflow#320)
Browse files Browse the repository at this point in the history
Autopep8 + flake8
  • Loading branch information
sethtroisi authored Aug 10, 2018
1 parent 4d4ce89 commit 53a354d
Show file tree
Hide file tree
Showing 33 changed files with 78 additions and 108 deletions.
3 changes: 1 addition & 2 deletions dual_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from absl import flags
import argparse
import functools
import math
import os.path
import sys

Expand Down Expand Up @@ -214,7 +213,7 @@ def model_fn(features, labels, mode, params=None):
tf.square(value_output - labels['value_tensor']))

reg_vars = [v for v in tf.trainable_variables()
if not 'bias' in v.name and not 'beta' in v.name]
if 'bias' not in v.name and 'beta' not in v.name]
l2_cost = FLAGS.l2_strength * \
tf.add_n([tf.nn.l2_loss(v) for v in reg_vars])

Expand Down
30 changes: 16 additions & 14 deletions example_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import multiprocessing as mp
import os
import random
import re
import subprocess
import time
from collections import deque
Expand Down Expand Up @@ -94,7 +93,7 @@ def update(self, new_games):
self.total_updates += num_new_games
self.examples.extend(self.func(game))
if first_new_game is None:
print ("No new games", file_timestamp(new_games[-1]), self.examples[-1][0])
print("No new games", file_timestamp(new_games[-1]), self.examples[-1][0])

def flush(self, path):
# random.shuffle on deque is O(n^2) convert to list for O(n)
Expand Down Expand Up @@ -122,6 +121,7 @@ def __str__(self):
def files_for_model(model):
return tf.gfile.Glob(os.path.join(LOCAL_DIR, model[1], '*.zz'))


def smart_rsync(
from_model_num=0,
source_dir=None,
Expand All @@ -133,6 +133,7 @@ def smart_rsync(
_rsync_dir(os.path.join(
source_dir, model), os.path.join(dest_dir, model))


def time_rsync(from_date,
source_dir=None,
dest_dir=LOCAL_DIR):
Expand Down Expand Up @@ -178,6 +179,7 @@ def _determine_chunk_to_make(write_dir):

return chunk_to_make, False


def get_window_size(chunk_num):
""" Adjust the window size by how far we are through a run.
At the start of the run, there's a benefit to 'expiring' the completely
Expand All @@ -186,10 +188,11 @@ def get_window_size(chunk_num):
"""
return min(500000, (chunk_num + 5) * (AVG_GAMES_PER_MODEL // 2))


def fill_and_wait_time(bufsize=dual_net.EXAMPLES_PER_GENERATION,
write_dir=None,
threads=32,
start_from=None):
write_dir=None,
threads=32,
start_from=None):
start_from = start_from or dt.datetime.utcnow()
write_dir = write_dir or fsdb.golden_chunk_dir()
buf = ExampleBuffer(bufsize)
Expand All @@ -202,8 +205,7 @@ def fill_and_wait_time(bufsize=dual_net.EXAMPLES_PER_GENERATION,

hours = fsdb.get_hour_dirs()
files = (tf.gfile.Glob(os.path.join(LOCAL_DIR, d, "*.zz"))
for d in reversed(hours)
if tf.gfile.Exists(os.path.join(LOCAL_DIR, d)))
for d in reversed(hours) if tf.gfile.Exists(os.path.join(LOCAL_DIR, d)))
files = itertools.islice(files, get_window_size(chunk_to_make))

models = fsdb.get_models()
Expand All @@ -216,32 +218,32 @@ def fill_and_wait_time(bufsize=dual_net.EXAMPLES_PER_GENERATION,
start_from = dt.datetime.utcnow()
hours = sorted(fsdb.get_hour_dirs(LOCAL_DIR))
new_files = list(map(lambda d: tf.gfile.Glob(
os.path.join(LOCAL_DIR, d, '*.zz')), hours[-2:]))
os.path.join(LOCAL_DIR, d, '*.zz')), hours[-2:]))
buf.update(list(itertools.chain.from_iterable(new_files)))
if fast_write:
break
time.sleep(30)
if fsdb.get_latest_model() != models[-1]:
print ("New model! Waiting for games. Got", buf.total_updates, "new games so far")
print("New model! Waiting for games. Got", buf.total_updates, "new games so far")

latest = fsdb.get_latest_model()
print("New model!", latest[1], "!=", models[-1][1])
print(buf)
buf.flush(chunk_to_make)


def fill_and_wait_models(bufsize=dual_net.EXAMPLES_PER_GENERATION,
write_dir=None,
threads=8,
model_window=100,
skip_first_rsync=False):
write_dir=None,
threads=8,
model_window=100,
skip_first_rsync=False):
""" Fills a ringbuffer with positions from the most recent games, then
continually rsync's and updates the buffer until a new model is promoted.
Once it detects a new model, iit then dumps its contents for training to
immediately begin on the next model.
"""
write_dir = write_dir or fsdb.golden_chunk_dir()
buf = ExampleBuffer(bufsize)
chunk_to_make = _determine_chunk_to_make(write_dir)
models = fsdb.get_models()[-model_window:]
if not skip_first_rsync:
with timer("Rsync"):
Expand Down
4 changes: 3 additions & 1 deletion fsdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@

FLAGS = flags.FLAGS


def _with_base(*args):
def inner():
base_dir = FLAGS.base_dir or 'gs://{}'.format(FLAGS.bucket_name)
return os.path.join(base_dir, *args)
return inner


# Functions to compute various important directories, based on FLAGS input.
models_dir = _with_base('models')
selfplay_dir = _with_base('data', 'selfplay')
Expand Down Expand Up @@ -84,6 +86,7 @@ def get_model(model_num):
model_names_by_num = dict(get_models())
return model_names_by_num[model_num]


def get_hour_dirs(root=None):
"""Gets the directories under selfplay_dir that match YYYY-MM-DD-HH."""
root = root or selfplay_dir()
Expand All @@ -96,4 +99,3 @@ def game_counts(n_back=20):
for _, model_name in get_models[-n_back:]:
games = gfile.Glob(os.path.join(selfplay_dir(), model_name, '*.zz'))
print(model_name, len(games))

9 changes: 5 additions & 4 deletions go.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def find_reached(board, c):
current = frontier.pop()
chain.add(current)
for n in NEIGHBORS[current]:
if board[n] == color and not n in chain:
if board[n] == color and n not in chain:
frontier.append(n)
elif board[n] != color:
reached.add(n)
Expand All @@ -108,7 +108,7 @@ def is_koish(board, c):
if board[c] != EMPTY:
return None
neighbors = {board[n] for n in NEIGHBORS[c]}
if len(neighbors) == 1 and not EMPTY in neighbors:
if len(neighbors) == 1 and EMPTY not in neighbors:
return list(neighbors)[0]
else:
return None
Expand Down Expand Up @@ -303,7 +303,8 @@ def __init__(self, board=None, n=0, komi=7.5, caps=(0, 0),
'''
assert type(recent) is tuple
self.board = board if board is not None else np.copy(EMPTY_BOARD)
self.n = n # With a full history, self.n == len(self.recent) == num moves played
# With a full history, self.n == len(self.recent) == num moves played
self.n = n
self.komi = komi
self.caps = caps
self.lib_tracker = lib_tracker or LibertyTracker.from_board(self.board)
Expand Down Expand Up @@ -401,7 +402,7 @@ def all_legal_moves(self):
legal_moves[self.board != EMPTY] = 0
# calculate which spots have 4 stones next to them
# padding is because the edge always counts as a lost liberty.
adjacent = np.ones([N+2, N+2], dtype=np.int8)
adjacent = np.ones([N + 2, N + 2], dtype=np.int8)
adjacent[1:-1, 1:-1] = np.abs(self.board)
num_adjacent_stones = (adjacent[:-2, 1:-1] + adjacent[1:-1, :-2] +
adjacent[2:, 1:-1] + adjacent[1:-1, 2:])
Expand Down
3 changes: 3 additions & 0 deletions gtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

FLAGS = flags.FLAGS


def make_gtp_instance(load_file, cgos_mode=False, kgs_mode=False, verbosity=1,
num_readouts=None):
'''Takes a path to model files and set up a GTP engine instance.'''
Expand All @@ -65,6 +66,7 @@ def make_gtp_instance(load_file, cgos_mode=False, kgs_mode=False, verbosity=1,

return engine


def main(argv):
'''Run Minigo in GTP mode.'''
del argv
Expand All @@ -77,5 +79,6 @@ def main(argv):
if not engine.handle_msg(msg.strip()):
break


if __name__ == '__main__':
app.run(main)
4 changes: 2 additions & 2 deletions gtp_cmd_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def cmd_boardsize(self, n: int):

def cmd_clear_board(self):
position = self._player.get_position()
if (self._player.get_result_string()
and position and len(position.recent) > 1):
if (self._player.get_result_string() and
position and len(position.recent) > 1):
try:
sgf = self._player.to_sgf()
with open(datetime.now().strftime("%Y-%m-%d-%H:%M.sgf"), 'w') as f:
Expand Down
6 changes: 1 addition & 5 deletions inference_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
import sys
import time
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.training import saver
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
import threading
import numpy as np
from absl import flags
Expand Down Expand Up @@ -131,7 +128,7 @@ def const_model_inference_fn(features):
def custom_getter(getter, name, *args, **kwargs):
with tf.control_dependencies(None):
return tf.guarantee_const(
getter(name, *args, **kwargs), name=name+"/GuaranteeConst")
getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
with tf.variable_scope("", custom_getter=custom_getter):
return dual_net.model_inference_fn(features, False)

Expand Down Expand Up @@ -312,7 +309,6 @@ def _prepare_features(self, raw_features):
features = []
for i in range(self._parallel_tpus):
begin = i * num_features
end = begin + num_features
x = np.frombuffer(
raw_features, dtype=np.int8, count=num_features, offset=begin)
x = x.reshape([self._batch_size, go.N, go.N,
Expand Down
5 changes: 1 addition & 4 deletions local_rl_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,13 @@
"""

import os
import sys
import tempfile

from absl import flags
import preprocessing
import dual_net
import go
import main
import selfplay
import example_buffer as eb
from tensorflow import gfile
import subprocess


Expand Down Expand Up @@ -117,6 +113,7 @@ def rl_loop():
sgf_dir=sgf_dir,
holdout_pct=0)


if __name__ == '__main__':
# horrible horrible hack to pass flag validation.
# Problems come from local_rl_loop calling into main() as library calls
Expand Down
9 changes: 1 addition & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,16 @@
import argh
import argparse
import os.path
import random
import socket
import sys
import tempfile
import time

import dual_net
import evaluation
import preprocessing
import utils

import cloud_logging
import tensorflow as tf
from absl import flags
from tqdm import tqdm
from tensorflow import gfile

# How many positions we should aggregate per 'chunk'.
Expand Down Expand Up @@ -67,7 +62,7 @@ def train_dir(


def train(tf_records: 'list of files of tf_records to train on',
model_save_path: 'Where to export the completed generation.'):
model_save_path: 'Where to export the completed generation.'):
print("Training on:", tf_records[0], "to", tf_records[-1])
with utils.logged_timer("Training"):
dual_net.train(*tf_records)
Expand Down Expand Up @@ -107,8 +102,6 @@ def evaluate(
black_net, white_net, games, output_dir, verbose)




def convert(load_file, dest_file):
from tensorflow.python.framework import meta_graph
features, labels = dual_net.get_inference_input()
Expand Down
11 changes: 5 additions & 6 deletions mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ def __repr__(self):

@property
def child_action_score(self):
return (self.child_Q * self.position.to_play
+ self.child_U
- 1000 * self.illegal_moves)
return (self.child_Q * self.position.to_play +
self.child_U - 1000 * self.illegal_moves)

@property
def child_Q(self):
Expand Down Expand Up @@ -142,9 +141,9 @@ def select_leaf(self):
break
# HACK: if last move was a pass, always investigate double-pass first
# to avoid situations where we auto-lose by passing too early.
if (current.position.recent
and current.position.recent[-1].move is None
and current.child_N[pass_move] == 0):
if (current.position.recent and
current.position.recent[-1].move is None and
current.child_N[pass_move] == 0):
current = current.maybe_add_child(pass_move)
continue

Expand Down
3 changes: 1 addition & 2 deletions minigui/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import json
import logging
import subprocess
from threading import Lock

parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -150,4 +149,4 @@ def index():


if __name__ == "__main__":
socketio.run(app, port=args.port, host=args.host)
socketio.run(app, port=args.port, host=args.host)
2 changes: 1 addition & 1 deletion oneoffs/compare_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def ReadExamples(path):


def main(unused_argv):
with tf.Session() as sess:
with tf.Session() as _:
examples_a = ReadExamples(FLAGS.a)
examples_b = ReadExamples(FLAGS.b)
print(len(examples_a), len(examples_b))
Expand Down
2 changes: 0 additions & 2 deletions oneoffs/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import sys
sys.path.insert(0, '.')

import itertools
import os

import numpy as np
import tensorflow as tf
from absl import app, flags
from tqdm import tqdm
Expand Down
2 changes: 1 addition & 1 deletion oneoffs/l2_cost_by_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def dual_net_list(model):
print("Dual Net will calculate L2 cost over these variables")
with dual.sess.graph.as_default():
var_names = [v.name for v in tf.trainable_variables()]
_ = reduce_and_print_vars(var_names)
reduce_and_print_vars(var_names)
print()


Expand Down
Loading

0 comments on commit 53a354d

Please sign in to comment.