
Commit

Merge pull request pytorch#78 from jma127/master
Fix half precision model.
jma127 authored Aug 7, 2018
2 parents ec83be4 + adba009 commit 2c37625
Showing 1 changed file with 32 additions and 37 deletions.
src_py/rlpytorch/utils/fp16_utils.py (69 changes: 32 additions & 37 deletions)
@@ -5,52 +5,47 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.nn as nn
 from rlpytorch import Model
 
 
-class tofp16(nn.Module):
-    def __init__(self):
-        super(tofp16, self).__init__()
-
-    def forward(self, input):
-        return input.half()
-
-
-def copy_in_params(net, params):
-    net_params = list(net.parameters())
-    for i in range(len(params)):
-        net_params[i].data.copy_(params[i].data)
-
-
-def set_grad(params, params_with_grad):
-    for param, param_w_grad in zip(params, params_with_grad):
-        if param.grad is None:
-            param.grad = torch.nn.Parameter(
-                param.data.new().resize_(*param.data.size()))
-        param.grad.data.copy_(param_w_grad.grad.data)
-
-
-def BN_convert_float(module):
-    '''
-    BatchNorm layers to have parameters in single precision.
-    Find all layers and convert them back to float. This can't
-    be done with built in .apply as that function will apply
-    fn to all modules, parameters, and buffers. Thus we wouldn't
-    be able to guard the float conversion based on the module type.
-    '''
-    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        module.float()
-    for child in module.children():
-        BN_convert_float(child)
-    return module
+def apply_nonrecursive(module, fn):
+    """Applies a given function only to parameters and buffers of a module.
+    Adapted from torch.nn.Module._apply.
+    """
+    for param in module._parameters.values():
+        if param is not None:
+            # Tensors stored in modules are graph leaves, and we don't
+            # want to create copy nodes, so we have to unpack the data.
+            param.data = fn(param.data)
+            if param._grad is not None:
+                param._grad.data = fn(param._grad.data)
+
+    for key, buf in module._buffers.items():
+        if buf is not None:
+            module._buffers[key] = fn(buf)
+
+
+def convert_fp16_if(module, condition):
+    """Nonrecursively converts a module's parameters and buffers to fp16
+    if a given condition is met.
+    """
+    if condition(module):
+        apply_nonrecursive(
+            module, lambda t: t.half() if t.is_floating_point() else t)
 
 
 class FP16Model(Model):
-    def __init__(self, option_map, params, network):
+    def __init__(self, option_map, params, model):
         super().__init__(option_map, params)
-        self.seq = nn.Sequential(tofp16(), BN_convert_float(network.half()))
-
-    def forward(self, s):
-        return self.seq(s)
+
+        def should_convert_to_fp16(module):
+            return not isinstance(
+                module, torch.nn.modules.batchnorm._BatchNorm)
+
+        self.fp16_model = convert_fp16_if(
+            model.float(), should_convert_to_fp16)
+
+    def forward(self, input):
+        fp16_input = input.half()
+        return self.fp16_model(fp16_input)
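
For context, the pattern introduced by this commit converts a network to half precision module by module while leaving BatchNorm layers in single precision. Below is a minimal standalone sketch of how the two helpers could be driven recursively; the helper definitions are repeated from the diff above, while the toy SmallNet module and the use of nn.Module.apply to walk the children are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn


def apply_nonrecursive(module, fn):
    """Apply fn to this module's own parameters and buffers (no recursion)."""
    for param in module._parameters.values():
        if param is not None:
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)
    for key, buf in module._buffers.items():
        if buf is not None:
            module._buffers[key] = fn(buf)


def convert_fp16_if(module, condition):
    """Convert this module's own floating-point tensors to fp16 if condition(module) holds."""
    if condition(module):
        apply_nonrecursive(
            module, lambda t: t.half() if t.is_floating_point() else t)


# Hypothetical toy network, used only to demonstrate the conversion.
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


def not_batchnorm(module):
    return not isinstance(module, torch.nn.modules.batchnorm._BatchNorm)


net = SmallNet().float()
# nn.Module.apply visits every submodule recursively, so the nonrecursive
# helper runs once per module and BatchNorm layers are skipped.
net.apply(lambda m: convert_fp16_if(m, not_batchnorm))

print(net.conv.weight.dtype)      # torch.float16
print(net.bn.weight.dtype)        # torch.float32 (BatchNorm parameters kept in fp32)
print(net.bn.running_mean.dtype)  # torch.float32 (buffers of skipped modules too)

Keeping BatchNorm in fp32 is the usual mixed-precision convention, since accumulating running means and variances in half precision can overflow or lose accuracy.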
