# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torch.autograd import Variable

from elf.options import auto_import_options, PyOptionSpec

from .utils import average_norm_clip


class ValueMatcher(object):
    @classmethod
    def get_option_spec(cls):
        spec = PyOptionSpec()
        spec.addFloatOption(
            'grad_clip_norm',
            'gradient norm clipping',
            0.0)
        spec.addStrOption(
            'value_node',
            'name of the value node',
            'V')
        return spec
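
    # Note (added): with ELF's option machinery, options declared here are
    # typically exposed as command-line flags named after their keys, e.g.
    # ``--grad_clip_norm 0.5 --value_node V`` (the flag spelling is an
    # assumption, not verified against this fork).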

    @auto_import_options
    def __init__(self, option_map):
        """Initialization of the value matcher.

        Initializes the value loss to ``nn.SmoothL1Loss`` (on GPU).
        """
        self.value_loss = nn.SmoothL1Loss().cuda()

    def _reg_backward(self, v):
        """Register a backward hook on ``v``; clip the gradient if necessary."""
        grad_clip_norm = self.options.grad_clip_norm
        if grad_clip_norm > 1e-20:
            def bw_hook(grad_in):
                # Clone so the in-place clip does not modify grad_in itself.
                grad = grad_in.clone()
                average_norm_clip(grad, grad_clip_norm)
                return grad
            v.register_hook(bw_hook)
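
    # Note (added): ``register_hook`` runs during ``backward()`` with the
    # gradient of the loss w.r.t. ``v``; a tensor returned from the hook
    # replaces that gradient, which is how the clipped clone above takes
    # effect downstream.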

    def feed(self, batch, stats):
        """One iteration of value matching.

        Computes the value error ``Loss(V, target)``; its backward pass
        yields the weight gradient ``nabla_w Loss(V, target)``.

        Keys used in ``batch``:

        ``V`` (Variable): predicted value
        ``target`` (tensor): target value

        Inputs of type ``Variable`` can backpropagate.

        Feeds the predicted value and the value error to ``stats``.

        Returns:
            value_err
        """
        V = batch[self.options.value_node]
        value_err = self.value_loss(V, Variable(batch["target"]))
        self._reg_backward(V)
        # Log the first sample's prediction and the scalar loss value.
        stats["predicted_" + self.options.value_node].feed(V.data[0])
        stats[self.options.value_node + "_err"].feed(value_err.data[0])

        return value_err
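

# ---------------------------------------------------------------------------
# Illustrative sketch (added; not part of the upstream file): the same
# match-and-clip pattern as a self-contained function in current PyTorch
# (where tensors subsume Variable), so it runs without the ELF option
# machinery or CUDA. Every name below is hypothetical; the default
# ``clip_norm=0.01`` is chosen small enough that the rescale branch actually
# fires for these toy inputs.
def _value_match_sketch(clip_norm=0.01):
    import torch

    v = torch.zeros(16, 1, requires_grad=True)  # stand-in for batch['V']
    target = torch.ones(16, 1)                  # stand-in for batch['target']

    def bw_hook(grad_in):
        # Rescale the incoming gradient so its average per-sample L2 norm is
        # at most clip_norm, mirroring what average_norm_clip is used for.
        grad = grad_in.clone()
        avg_norm = grad.view(grad.size(0), -1).norm(dim=1).mean()
        if avg_norm > clip_norm:
            grad = grad * (clip_norm / avg_norm)
        return grad

    v.register_hook(bw_hook)
    value_err = torch.nn.SmoothL1Loss()(v, target)
    value_err.backward()  # bw_hook fires here and rescales v.grad
    return value_err.item(), v.grad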