mcts_prediction.py (forked from pytorch/ELF)

# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torch.autograd import Variable

import elf.logging as logging
from elf.options import auto_import_options, PyOptionSpec
from rlpytorch.trainer.timer import RLTimer


_logger_factory = logging.IndexedLoggerFactory(
    lambda name: logging.stdout_color_mt(name))


class MCTSPrediction(object):
    @classmethod
    def get_option_spec(cls):
        spec = PyOptionSpec()
        spec.addBoolOption(
            'backprop',
            'Whether to backprop the total loss',
            True)
        return spec

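    # The 'backprop' flag declared above is read in update() as
    # self.options.backprop (presumably populated by @auto_import_options).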
    @auto_import_options
    def __init__(self, option_map):
        self.policy_loss = nn.KLDivLoss().cuda()
        self.value_loss = nn.MSELoss().cuda()
        self.logger = _logger_factory.makeLogger(
            'elfgames.go.MCTSPrediction-', '')
        self.timer = RLTimer()

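    # update() expects `batch` to provide "mcts_scores" (a distribution over
    # moves used as the policy target) and "winner" (the game outcome used as
    # the value target); the model output must provide "logpi", "pi" and,
    # optionally, "V".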
    def update(self, mi, batch, stats, use_cooldown=False, cooldown_count=0):
        ''' Update given batch '''
        self.timer.restart()
        if use_cooldown:
            if cooldown_count == 0:
                mi['model'].prepare_cooldown()
                self.timer.record('prepare_cooldown')

        # Current timestep.
        state_curr = mi['model'](batch)
        self.timer.record('forward')

        if use_cooldown:
            self.logger.debug(self.timer.print(1))
            return dict(backprop=False)

        targets = batch["mcts_scores"]
        logpi = state_curr["logpi"]
        pi = state_curr["pi"]
        # backward.
        # loss = self.policy_loss(logpi, Variable(targets)) * logpi.size(1)
        loss = - (logpi * Variable(targets)
                  ).sum(dim=1).mean()  # * logpi.size(1)
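        # Note: this cross-entropy and the commented-out KLDivLoss variant
        # (with the size(1) rescaling) differ only by the entropy of `targets`,
        # which does not depend on the model, so the policy gradients are
        # identical.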
stats["loss"].feed(float(loss))
total_policy_loss = loss
entropy = (logpi * pi).sum() * -1 / logpi.size(0)
stats["entropy"].feed(float(entropy))
stats["blackwin"].feed(
float((batch["winner"] > 0.0).float().sum()) /
batch["winner"].size(0))
total_value_loss = None
if "V" in state_curr and "winner" in batch:
total_value_loss = self.value_loss(
state_curr["V"].squeeze(), Variable(batch["winner"]))
stats["total_policy_loss"].feed(float(total_policy_loss))
if total_value_loss is not None:
stats["total_value_loss"].feed(float(total_value_loss))
total_loss = total_policy_loss + total_value_loss
else:
total_loss = total_policy_loss
stats["total_loss"].feed(float(total_loss))
self.timer.record('feed_stats')
if self.options.backprop:
total_loss.backward()
self.timer.record('backward')
self.logger.debug(self.timer.print(1))
return dict(backprop=True)
else:
self.logger.debug(self.timer.print(1))
return dict(backprop=False)
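
For reference, the snippet below is a standalone sketch, not part of the original file: using only toy tensors and torch, it illustrates why the cross-entropy actually used in update() and the commented-out nn.KLDivLoss variant differ only by the entropy of the MCTS targets, a term that is constant with respect to the model parameters.

# Standalone sketch -- not part of mcts_prediction.py; toy tensors only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_actions = 4, 9

# Fake policy log-probabilities and MCTS target distributions (rows sum to 1).
logpi = F.log_softmax(torch.randn(batch_size, num_actions), dim=1)
targets = F.softmax(torch.randn(batch_size, num_actions), dim=1)

# Loss used in update(): mean per-sample cross-entropy H(targets, pi).
ce = -(logpi * targets).sum(dim=1).mean()

# Commented-out variant: elementwise-mean KL divergence, rescaled by the
# action dimension so it becomes the mean per-sample KL(targets || pi).
kl = F.kl_div(logpi, targets, reduction='mean') * num_actions

# The two differ only by the entropy of the targets, which carries no gradient.
target_entropy = -(targets * targets.log()).sum(dim=1).mean()
assert torch.allclose(ce, kl + target_entropy, atol=1e-6)
print(float(ce), float(kl), float(target_entropy))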