This repository was archived by the owner on Mar 11, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 562
/
Copy pathtrain_and_validate.py
154 lines (128 loc) · 5.04 KB
/
train_and_validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Run train and validate in a loop, as subprocesses.
We run as subprocesses because it gives us some isolation.
"""
import itertools
import os
import sys
import time
sys.path.insert(0, '.')
from absl import app, flags
from tensorflow import gfile
import tensorflow as tf
from rl_loop import fsdb
import mask_flags
from rl_loop import shipname
import utils
import dual_net
flags.DEFINE_string('pro_dataset', None,
'Location of preprocessed pro dataset for validation')
# From fsdb.py - must pass one of the two.
flags.declare_key_flag('base_dir')
flags.declare_key_flag('bucket_name')
FLAGS = flags.FLAGS
try:
if 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ:
TPU_NAME = os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
else:
TPU_NAME = os.environ['TPU_NAME']
except KeyError:
raise Exception("Must have $TPU_NAME configured")
def train():
model_num, model_name = fsdb.get_latest_model()
print("Training on gathered game data, initializing from {}".format(
model_name))
new_model_num = model_num + 1
new_model_name = shipname.generate(new_model_num)
print("New model will be {}".format(new_model_name))
save_file = os.path.join(fsdb.models_dir(), new_model_name)
# TODO(jacksona): Refactor train.py to take the filepath as a flag.
cmd = ['python3', 'train.py', '__unused_file__',
'--use_tpu',
'--use_bt',
'--work_dir={}'.format(fsdb.working_dir()),
'--tpu_name={}'.format(TPU_NAME),
'--flagfile=rl_loop/distributed_flags',
'--export_path={}'.format(save_file)]
completed_process = mask_flags.run(cmd)
if completed_process.returncode > 0:
print("Training failed!")
return completed_process
# Train.py already copies the {data,index,meta} files to $BUCKET/models
# Persist the checkpoint two ways:
# Freeze the .ckpt file in the work_dir for the TPU selfplayers
# Freeze a non-tpu version of the graph for later GPU use.
latest_checkpoint = tf.train.latest_checkpoint(fsdb.working_dir())
p = freeze(latest_checkpoint, rewrite_tpu=True)
if p.returncode > 0:
print("== TPU freeze failed!")
return p
p = freeze(save_file, rewrite_tpu=False)
if p.returncode > 0:
print("== Model freeze failed!")
return p
return completed_process
def freeze(save_path, rewrite_tpu=False):
cmd = ['python3', 'freeze_graph.py',
'--work_dir={}'.format(fsdb.working_dir()),
'--flagfile=rl_loop/distributed_flags',
'--model_path={}'.format(save_path)]
if rewrite_tpu:
cmd.extend(['--use_tpu',
'--tpu_name={}'.format(TPU_NAME)])
return mask_flags.run(cmd)
def validate_holdout_selfplay():
"""Validate on held-out selfplay data."""
holdout_dirs = (os.path.join(fsdb.holdout_dir(), d)
for d in reversed(gfile.ListDirectory(fsdb.holdout_dir()))
if gfile.IsDirectory(os.path.join(fsdb.holdout_dir(), d))
for f in gfile.ListDirectory(os.path.join(fsdb.holdout_dir(), d)))
# This is a roundabout way of computing how many hourly directories we need
# to read in order to encompass 20,000 holdout games.
holdout_dirs = set(itertools.islice(holdout_dirs), 20000)
cmd = ['python3', 'validate.py'] + list(holdout_dirs) + [
'--use_tpu',
'--tpu_name={}'.format(TPU_NAME),
'--flagfile=rl_loop/distributed_flags',
'--expand_validation_dirs']
mask_flags.run(cmd)
def validate_pro():
"""Validate on professional data."""
cmd = ['python3', 'validate.py', FLAGS.pro_dataset,
'--use_tpu',
'--tpu_name={}'.format(TPU_NAME),
'--work_dir={}'.format(fsdb.working_dir()),
'--flagfile=rl_loop/distributed_flags',
'--validate_name=pro']
mask_flags.run(cmd)
def loop(unused_argv):
if len(fsdb.get_models()) == 0:
# TODO(amj): Do bootstrap here.
pass
while True:
print("=" * 40, flush=True)
with utils.timer("Train"):
completed_process = train()
if completed_process.returncode > 0:
print("Training failed, aborting.")
sys.exit(1)
with utils.timer("Validate"):
if not FLAGS.pro_dataset:
print("*** --pro_dataset not set, skipping pro validation ***")
else:
validate_pro()
#validate_holdout_selfplay()
if __name__ == '__main__':
app.run(loop)