Skip to content

Commit

Permalink
Implement stochastic weight averaging.
Browse files Browse the repository at this point in the history
* Implement stochastic weight averaging.
* Recalculate SWA batch norm.
* Don't run train_op when recalculating BN. It modifies the momentum
  state for the training run.
* Save and restore support.
* Add option to limit the maximum number of networks to average.
* Use lower number of steps in BN recalculation. When it's done for
  every output network it should already be almost correct.

Pull request leela-zero#1064.
  • Loading branch information
Ttl authored and gcp committed Mar 23, 2018
1 parent 58ee56f commit 94f48bb
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 36 deletions.
74 changes: 74 additions & 0 deletions training/tf/average_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
#
# This file is part of Leela Zero.
# Copyright (C) 2017 Henrik Forsten
#
# Leela Zero is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Leela Zero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.

import argparse
import numpy as np

def swa(inputs, output, weights=None):
    """ Average weights of the weight files.

    Every line of every input file is parsed as a row of space-separated
    numbers, scaled by that file's normalized weight and summed row by row
    across the files. The result is written to output; the first output
    line is always the version tag '1'.

    inputs  : List of filenames to use as inputs
    output  : String of output filename
    weights : List of numbers to use for weighting the inputs. They are
              normalized to sum to one. Defaults to equal weighting.

    Raises ValueError if no inputs are given, if the number of weights
    doesn't match the number of inputs, or if the input files differ in
    their number of rows or in the length of any row.
    """
    if not inputs:
        # Without this check an empty output file would be written silently.
        raise ValueError("No input files given")

    if weights is None:
        weights = [1.0] * len(inputs)

    if len(weights) != len(inputs):
        raise ValueError("Number of weights doesn't match number of input files")

    # Normalize weights. The total is computed once, not once per element.
    total = sum(weights)
    weights = [float(w) / total for w in weights]

    out_weights = []

    for count, filename in enumerate(inputs):
        scale = weights[count]
        with open(filename, 'r') as f:
            # strip() makes a trailing space before the newline harmless;
            # blank or otherwise malformed rows still raise ValueError.
            weights_in = [
                scale * np.array(list(map(float, line.strip().split(' '))))
                for line in f
            ]

        if count == 0:
            out_weights = weights_in
            continue

        if len(out_weights) != len(weights_in):
            raise ValueError("Nets have different sizes")
        for e, w in enumerate(weights_in):
            if len(w) != len(out_weights[e]):
                raise ValueError("Nets have different sizes")
            # Accumulate in place into the running weighted sum.
            out_weights[e] += w

    with open(output, 'w') as f:
        for e, w in enumerate(out_weights):
            if e == 0:
                # Version line is rewritten verbatim, not averaged.
                f.write('1\n')
            else:
                f.write(' '.join(map(str, w)) + '\n')

if __name__ == "__main__":
    # Command-line entry point: average the given weight files into one.
    parser = argparse.ArgumentParser(description='Average weight files.')
    # Inputs and output are mandatory. Marking them required makes argparse
    # print a usage error instead of letting swa() fail on a None argument.
    parser.add_argument('-i', '--inputs', nargs='+', required=True,
                        help='List of input weight files')
    # Optional per-file weights; swa() falls back to equal weighting.
    parser.add_argument('-w', '--weights', type=float, nargs='+',
                        help='List of weights to use for the each weight file during averaging.')
    parser.add_argument('-o', '--output', required=True,
                        help='Output filename')

    args = parser.parse_args()

    swa(args.inputs, args.output, args.weights)
60 changes: 26 additions & 34 deletions training/tf/net_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,33 @@
import tensorflow as tf
import os
import sys
from tfprocess import TFProcess
from tfprocess import TFProcess, read_weights

with open(sys.argv[1], 'r') as f:
weights = []
for e, line in enumerate(f):
if e == 0:
#Version
print("Version", line.strip())
if line != '1\n':
raise ValueError("Unknown version {}".format(line.strip()))
else:
weights.append(list(map(float, line.split(' '))))
if e == 2:
channels = len(line.split(' '))
print("Channels", channels)
blocks = e - (4 + 14)
if blocks % 8 != 0:
raise ValueError("Inconsistent number of weights in the file")
blocks //= 8

if __name__ == "__main__":
version, blocks, channels, weights = read_weights(sys.argv[1])

if version == None:
raise ValueError("Unable to read version number")

print("Version", version)
print("Channels", channels)
print("Blocks", blocks)

x = [
tf.placeholder(tf.float32, [None, 18, 19 * 19]),
tf.placeholder(tf.float32, [None, 362]),
tf.placeholder(tf.float32, [None, 1])
]
x = [
tf.placeholder(tf.float32, [None, 18, 19 * 19]),
tf.placeholder(tf.float32, [None, 362]),
tf.placeholder(tf.float32, [None, 1])
]

tfprocess = TFProcess()
tfprocess.init_net(x)
if tfprocess.RESIDUAL_BLOCKS != blocks:
raise ValueError("Number of blocks in tensorflow model doesn't match "\
"number of blocks in input network")
if tfprocess.RESIDUAL_FILTERS != channels:
raise ValueError("Number of filters in tensorflow model doesn't match "\
"number of filters in input network")
tfprocess.replace_weights(weights)
path = os.path.join(os.getcwd(), "leelaz-model")
save_path = tfprocess.saver.save(tfprocess.session, path, global_step=0)
tfprocess = TFProcess()
tfprocess.init_net(x)
if tfprocess.RESIDUAL_BLOCKS != blocks:
raise ValueError("Number of blocks in tensorflow model doesn't match "\
"number of blocks in input network")
if tfprocess.RESIDUAL_FILTERS != channels:
raise ValueError("Number of filters in tensorflow model doesn't match "\
"number of filters in input network")
tfprocess.replace_weights(weights)
path = os.path.join(os.getcwd(), "leelaz-model")
save_path = tfprocess.saver.save(tfprocess.session, path, global_step=0)
93 changes: 91 additions & 2 deletions training/tf/tfprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import numpy as np
import time
import tensorflow as tf
from shutil import copyfile
from average_weights import swa

def weight_variable(shape):
"""Xavier initialization"""
Expand All @@ -46,6 +48,28 @@ def conv2d(x, W):
return tf.nn.conv2d(x, W, data_format='NCHW',
strides=[1, 1, 1, 1], padding='SAME')

def read_weights(filename):
    """ Read weights from file to array.

    The first line of the file holds the version number (only version 1 is
    accepted). Every following line is one row of space-separated floats.

    Returns a tuple (version, blocks, channels, weights):
    version  : Version number read from the first line
    blocks   : Block count derived from the line count; the file must hold
               4 + 14 rows plus 8 rows per block after the version line
    channels : Number of values on the third line of the file
    weights  : List of rows, each a list of floats (version line excluded)

    Raises ValueError if the version is unknown, a row cannot be parsed, or
    the number of lines doesn't fit the expected layout (including empty
    or truncated files).
    """
    weights = []
    version = None
    channels = None
    # Index of the last line read; stays -1 for an empty file so the layout
    # check below rejects it instead of hitting an unbound variable.
    last = -1

    with open(filename, 'r') as f:
        for last, line in enumerate(f):
            if last == 0:
                # Version
                version = int(line.strip())
                if version != 1:
                    raise ValueError("Unknown version {}".format(line.strip()))
            else:
                # strip() makes a trailing space before the newline harmless.
                row = list(map(float, line.strip().split(' ')))
                weights.append(row)
                if last == 2:
                    channels = len(row)

    blocks = last - (4 + 14)
    # A negative count can also be a multiple of 8, so reject it explicitly.
    if blocks < 0 or blocks % 8 != 0:
        raise ValueError("Inconsistent number of weights in the file")
    blocks //= 8
    return version, blocks, channels, weights

class TFProcess:
def __init__(self):
# Network structure
Expand Down Expand Up @@ -79,6 +103,26 @@ def init_net(self, next_batch):
self.batch_norm_count = 0
self.y_conv, self.z_conv = self.construct_net(self.x)

# Output weight file with averaged weights
self.swa_enabled = True

# Nets to skip
# Output net number n is used for averaging if n % c == 0
self.swa_c = 1

# Maximum number of nets to average
# Set to None to disable the limit
self.swa_max_n = 16

# Filename for initial averaged network
self.prev_swa = tf.Variable('', trainable=False)

# Recalculate SWA weight batchnorm means and variances
self.swa_recalc_bn = True

# Nets written to disk
self.output_nets = tf.Variable(0, trainable=False)

# Calculate loss on policy head
cross_entropy = \
tf.nn.softmax_cross_entropy_with_logits(labels=self.y_,
Expand Down Expand Up @@ -233,13 +277,22 @@ def process(self, batch_size):
self.test_writer.add_summary(test_summaries, steps)
print("step {}, policy={:g} training accuracy={:g}%, mse={:g}".\
format(steps, sum_policy, sum_accuracy*100.0, sum_mse))

path = os.path.join(os.getcwd(), "leelaz-model")
save_path = self.saver.save(self.session, path, global_step=steps)
print("Model saved in file: {}".format(save_path))
leela_path = path + "-" + str(steps) + ".txt"
self.save_leelaz_weights(leela_path)
print("Leela weights saved to {}".format(leela_path))

prev_swa, output_nets = self.session.run([self.prev_swa, self.output_nets])
if self.swa_enabled and output_nets % self.swa_c == 0:
self.save_swa_network(steps, path, leela_path,
prev_swa, output_nets)

self.session.run(tf.assign(self.output_nets, output_nets + 1))
save_path = self.saver.save(self.session, path, global_step=steps)
print("Model saved in file: {}".format(save_path))


def save_leelaz_weights(self, filename):
with open(filename, "w") as file:
# Version tag
Expand Down Expand Up @@ -380,3 +433,39 @@ def construct_net(self, planes):
h_fc3 = tf.nn.tanh(tf.add(tf.matmul(h_fc2, W_fc3), b_fc3))

return h_fc1, h_fc3

def save_swa_network(self, steps, path, leela_path, prev_swa, output_nets):
""" Write a stochastic weight averaging (SWA) network file for this step.
steps : Step count, used in the output filename
path : Filename prefix of the output file
leela_path : Weight file of the current network
prev_swa : Filename of the previous averaged network. If it is not an
existing file, leela_path is copied as the first average.
output_nets : Number of nets written so far
NOTE(review): leading indentation is missing in this view of the file, so
the nesting of the 'if n > 0 and self.swa_recalc_bn' section relative to
the 'else' branch should be confirmed against the real source.
"""
# n is the weight given to the previous average (the current net gets
# weight 1), capped at swa_max_n when a limit is set.
n = output_nets // self.swa_c
if self.swa_max_n != None:
n = min(n, self.swa_max_n)

# The output filename records n + 1 and the step count.
swa_path = path + "-swa-" + str(n + 1) + "-" + str(steps) + ".txt"

if not os.path.isfile(prev_swa):
# Average of one network is the network itself
copyfile(leela_path, swa_path)
else:
# When batch norm is going to be recalculated the average is first
# written to a temporary file; otherwise it goes straight to swa_path.
if self.swa_recalc_bn:
swa([prev_swa, leela_path], 'swa_temp.txt', weights=[n, 1])
else:
swa([prev_swa, leela_path], swa_path, weights=[n, 1])

if n > 0 and self.swa_recalc_bn:
# Load SWA weights for batch norm recalculation
version, blocks, channels, weights = read_weights('swa_temp.txt')
self.replace_weights(weights)

print("Recalculating SWA batch normalization")
# Evaluate the loss ops on 200 training batches with training=True.
# No train op is in the fetch list, so no optimizer step is run here.
# Presumably the forward passes refresh the batch norm statistics
# (TODO confirm against construct_net).
for _ in range(200):
self.session.run(
[self.policy_loss, self.mse_loss, self.reg_term, self.next_batch],
feed_dict={self.training: True, self.handle: self.train_handle})

# Write the averaged network from the weights currently in the session.
self.save_leelaz_weights(swa_path)

# Reload the training weights that were replaced by the averaged ones
version, blocks, channels, weights = read_weights(leela_path)
self.replace_weights(weights)

# Remember this file as the previous average for the next call.
self.session.run(tf.assign(self.prev_swa, swa_path))
print("Wrote averaged network to {}".format(swa_path))

0 comments on commit 94f48bb

Please sign in to comment.